# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you ! We ❤️ you! - Krrish & Ishaan

import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
    CustomHTTPTransport,
    AsyncCustomHTTPTransport,
)
from litellm.utils import (
    ModelResponse,
    CustomStreamWrapper,
    get_utc_datetime,
    calculate_max_parallel_requests,
    _is_region_eu,
)
import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.router import (
    Deployment,
    ModelInfo,
    LiteLLM_Params,
    RouterErrors,
    updateDeployment,
    updateLiteLLMParams,
    RetryPolicy,
    AlertingConfig,
    DeploymentTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc


class Router:
    # Class-level defaults. `__init__` / `routing_strategy_init` rebind some of
    # these on the instance (see the per-attribute notes).
    # NOTE(review): mutable class attribute - shared by every instance unless it
    # is rebound; it is not populated anywhere in this part of the file.
    model_names: List = []
    # Rebound on the instance by `__init__` when response caching is requested.
    cache_responses: Optional[bool] = False
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
    # NOTE(review): not referenced in this part of the file - purpose to be confirmed.
    tenacity = None
    # Set by `routing_strategy_init` when the strategy is "least-busy".
    leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
    # Set by `routing_strategy_init` when the strategy is "usage-based-routing".
    lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None

    def __init__(
        self,
        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
        ## CACHING ##
        redis_url: Optional[str] = None,
        redis_host: Optional[str] = None,
        redis_port: Optional[int] = None,
        redis_password: Optional[str] = None,
        cache_responses: Optional[bool] = False,
        cache_kwargs: dict = {},  # additional kwargs to pass to RedisCache (see caching.py)
        caching_groups: Optional[
            List[tuple]
        ] = None,  # if you want to cache across model groups
        client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
        ## RELIABILITY ##
        num_retries: Optional[int] = None,
        timeout: Optional[float] = None,
        default_litellm_params: Optional[
            dict
        ] = None,  # default params for Router.chat.completion.create
        default_max_parallel_requests: Optional[int] = None,
        set_verbose: bool = False,
        debug_level: Literal["DEBUG", "INFO"] = "INFO",
        default_fallbacks: Optional[
            List[str]
        ] = None,  # generic fallbacks, works across all deployments
        fallbacks: List = [],
        context_window_fallbacks: List = [],
        model_group_alias: Optional[dict] = {},
        enable_pre_call_checks: bool = False,
        retry_after: int = 0,  # min time to wait before retrying a failed request
        retry_policy: Optional[
            RetryPolicy
        ] = None,  # set custom retries for different exceptions
        model_group_retry_policy: Optional[
            Dict[str, RetryPolicy]
        ] = {},  # set custom retry policies based on model group
        allowed_fails: Optional[
            int
        ] = None,  # Number of times a deployment can fail before being added to cooldown
        cooldown_time: float = 1,  # (seconds) time to cooldown a deployment after failure
        routing_strategy: Literal[
            "simple-shuffle",
            "least-busy",
            "usage-based-routing",
            "latency-based-routing",
            "cost-based-routing",
            "usage-based-routing-v2",
        ] = "simple-shuffle",
        routing_strategy_args: dict = {},  # just for latency-based routing
        semaphore: Optional[asyncio.Semaphore] = None,
        alerting_config: Optional[AlertingConfig] = None,
    ) -> None:
        """
        Initialize the Router with the given parameters for caching, reliability, and routing strategy.

        Besides setting instance state, this constructor has module-level side effects:
        it may set `litellm.cache`, it registers `self.deployment_callback_on_failure`
        in `litellm.failure_callback`, the chosen routing strategy may register a
        handler in `litellm.callbacks` / `litellm.input_callback`, and a summary
        line is printed to stdout.

        Args:
            model_list (Optional[list]): List of models to be used. Deep-copied before use. Defaults to None.
            redis_url (Optional[str]): URL of the Redis server. Defaults to None.
            redis_host (Optional[str]): Hostname of the Redis server. Defaults to None.
            redis_port (Optional[int]): Port of the Redis server. Defaults to None.
            redis_password (Optional[str]): Password of the Redis server. Defaults to None.
                Redis is used when `redis_url` is set, or when host, port and password are all set.
            cache_responses (Optional[bool]): Flag to enable caching of responses. Defaults to False.
            cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
            caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
            client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
            num_retries (Optional[int]): Number of retries for failed requests. When None, falls back to
                `litellm.num_retries`, then to `openai.DEFAULT_MAX_RETRIES`.
            timeout (Optional[float]): Timeout for requests. When falsy, falls back to `litellm.request_timeout`.
            default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
                Note: "timeout", "max_retries" and "metadata" entries are added to this dict in place.
            default_max_parallel_requests (Optional[int]): Stored as `self.default_max_parallel_requests`. Defaults to None.
            set_verbose (bool): Flag to set verbose mode. Defaults to False.
            debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
            default_fallbacks (Optional[List[str]]): Generic fallbacks; stored as a `{"*": default_fallbacks}` entry
                appended to the router's fallbacks. Falls back to `litellm.default_fallbacks`.
            fallbacks (List): List of single-key fallback dicts. When empty, `litellm.fallbacks` is used. Defaults to [].
            context_window_fallbacks (List): List of context window fallback options. When empty,
                `litellm.context_window_fallbacks` is used. Defaults to [].
            model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
            enable_pre_call_checks (boolean): Filter out deployments which are outside context window limits for a given prompt
            retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
            retry_policy (Optional[RetryPolicy]): Custom retries for different exceptions. Defaults to None.
            model_group_retry_policy (Optional[Dict[str, RetryPolicy]]): Custom retry policies per model group. Defaults to {}.
            allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. When None,
                `litellm.allowed_fails` is used; an explicit 0 is honoured.
            cooldown_time (float): Time to cooldown a deployment after failure in seconds. A falsy value is replaced by 1. Defaults to 1.
            routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing", "usage-based-routing-v2"]): Routing strategy. Defaults to "simple-shuffle".
            routing_strategy_args (dict): Additional args for the routing strategy handler. Defaults to {}.
            semaphore (Optional[asyncio.Semaphore]): Stored as `self.semaphore` when provided. Defaults to None.
            alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.

        Raises:
            AssertionError: if the effective `fallbacks` list is malformed (see `validate_fallbacks`).

        Example Usage:
        ```python
        from litellm import Router
        model_list = [
            {
                "model_name": "azure-gpt-3.5-turbo", # model alias
                "litellm_params": { # params for litellm completion/embedding call
                    "model": "azure/<your-deployment-name-1>",
                    "api_key": <your-api-key>,
                    "api_version": <your-api-version>,
                    "api_base": <your-api-base>
                },
            },
            {
                "model_name": "azure-gpt-3.5-turbo", # model alias
                "litellm_params": { # params for litellm completion/embedding call
                    "model": "azure/<your-deployment-name-2>",
                    "api_key": <your-api-key>,
                    "api_version": <your-api-version>,
                    "api_base": <your-api-base>
                },
            },
            {
                "model_name": "openai-gpt-3.5-turbo", # model alias
                "litellm_params": { # params for litellm completion/embedding call
                    "model": "gpt-3.5-turbo",
                    "api_key": <your-api-key>,
                },
            },
        ]

        router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
        ```
        """

        if semaphore:
            self.semaphore = semaphore
        self.set_verbose = set_verbose
        self.debug_level = debug_level
        self.enable_pre_call_checks = enable_pre_call_checks
        if self.set_verbose == True:
            if debug_level == "INFO":
                verbose_router_logger.setLevel(logging.INFO)
            elif debug_level == "DEBUG":
                verbose_router_logger.setLevel(logging.DEBUG)

        self.deployment_names: List = (
            []
        )  # names of models under litellm_params. ex. azure/chatgpt-v-2
        self.deployment_latency_map = {}
        ### CACHING ###
        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
        redis_cache = None
        cache_config = {}
        self.client_ttl = client_ttl
        if redis_url is not None or (
            redis_host is not None
            and redis_port is not None
            and redis_password is not None
        ):
            cache_type = "redis"

            if redis_url is not None:
                cache_config["url"] = redis_url

            if redis_host is not None:
                cache_config["host"] = redis_host

            if redis_port is not None:
                cache_config["port"] = str(redis_port)  # type: ignore

            if redis_password is not None:
                cache_config["password"] = redis_password

            # Add additional key-value pairs from cache_kwargs
            cache_config.update(cache_kwargs)
            redis_cache = RedisCache(**cache_config)

        if cache_responses:
            if litellm.cache is None:
                # the cache can be initialized on the proxy server. We should not overwrite it
                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
            self.cache_responses = cache_responses
        self.cache = DualCache(
            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.

        self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
        self.default_max_parallel_requests = default_max_parallel_requests

        if model_list is not None:
            # work on a private copy so later changes never touch the caller's list
            model_list = copy.deepcopy(model_list)
            self.set_model_list(model_list)
            self.healthy_deployments: List = self.model_list  # type: ignore
            for m in model_list:
                self.deployment_latency_map[m["litellm_params"]["model"]] = 0
        else:
            self.model_list: List = (
                []
            )  # initialize an empty list - to allow _add_deployment and delete_deployment to work

        # `is not None` (rather than `or`) so that an explicit allowed_fails=0
        # is honoured instead of being replaced by the global default
        if allowed_fails is not None:
            self.allowed_fails = allowed_fails
        else:
            self.allowed_fails = litellm.allowed_fails
        self.cooldown_time = cooldown_time or 1
        self.failed_calls = (
            InMemoryCache()
        )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown

        if num_retries is not None:
            self.num_retries = num_retries
        elif litellm.num_retries is not None:
            self.num_retries = litellm.num_retries
        else:
            self.num_retries = openai.DEFAULT_MAX_RETRIES

        self.timeout = timeout or litellm.request_timeout

        self.retry_after = retry_after
        self.routing_strategy = routing_strategy

        ## SETTING FALLBACKS ##
        ### validate if it's set + in correct format
        _fallbacks = fallbacks or litellm.fallbacks

        self.validate_fallbacks(fallback_param=_fallbacks)
        ### set fallbacks
        # Keep a shallow copy: the "*" entry appended below must not leak into the
        # caller's list or into the module-level `litellm.fallbacks` (which would
        # otherwise grow by one entry for every Router that gets created).
        self.fallbacks: Optional[List] = (
            list(_fallbacks) if _fallbacks is not None else None
        )

        if default_fallbacks is not None or litellm.default_fallbacks is not None:
            _default_fallbacks = default_fallbacks or litellm.default_fallbacks
            if self.fallbacks is not None:
                self.fallbacks.append({"*": _default_fallbacks})
            else:
                self.fallbacks = [{"*": _default_fallbacks}]

        self.context_window_fallbacks = (
            context_window_fallbacks or litellm.context_window_fallbacks
        )
        self.total_calls: defaultdict = defaultdict(
            int
        )  # dict to store total calls made to each model
        self.fail_calls: defaultdict = defaultdict(
            int
        )  # dict to store fail_calls made to each model
        self.success_calls: defaultdict = defaultdict(
            int
        )  # dict to store success_calls  made to each model
        self.previous_models: List = (
            []
        )  # list to store failed calls (passed in as metadata to next call)
        self.model_group_alias: dict = (
            model_group_alias or {}
        )  # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group

        # make Router.chat.completions.create compatible for openai.chat.completions.create
        default_litellm_params = default_litellm_params or {}
        self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)

        # default litellm args
        self.default_litellm_params = default_litellm_params
        self.default_litellm_params.setdefault("timeout", timeout)
        self.default_litellm_params.setdefault("max_retries", 0)
        self.default_litellm_params.setdefault("metadata", {}).update(
            {"caching_groups": caching_groups}
        )

        self.deployment_stats: dict = {}  # used for debugging load balancing
        # deployment_stats = {
        #     "122999-2828282-277": {
        #         "model": "gpt-3",
        #         "api_base": "http://localhost:4000",
        #         "num_requests": 20,
        #         "avg_latency": 0.001,
        #         "num_failures": 0,
        #         "num_successes": 20
        #     }
        # }
        ### ROUTING SETUP ###
        self.routing_strategy_init(
            routing_strategy=routing_strategy,
            routing_strategy_args=routing_strategy_args,
        )
        ## COOLDOWNS ##
        if isinstance(litellm.failure_callback, list):
            litellm.failure_callback.append(self.deployment_callback_on_failure)
        else:
            litellm.failure_callback = [self.deployment_callback_on_failure]
        print(  # noqa
            f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
        )  # noqa
        self.routing_strategy_args = routing_strategy_args
        self.retry_policy: Optional[RetryPolicy] = retry_policy
        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
            model_group_retry_policy
        )
        self.alerting_config: Optional[AlertingConfig] = alerting_config
        if self.alerting_config is not None:
            self._initialize_alerting()

    def validate_fallbacks(self, fallback_param: Optional[List]):
        """
        Check that a fallbacks list has the expected shape.

        Every item must be a dictionary holding exactly one key.
        `None` and an empty list are accepted as "no fallbacks configured".

        Args:
            fallback_param (Optional[List]): the fallbacks list to validate.

        Raises:
            AssertionError: if an item is not a dict, or does not have exactly one key.
        """
        if fallback_param is None:
            return
        for fallback in fallback_param:
            # Raised explicitly (instead of via `assert`) so the validation still
            # runs when Python is started with -O; the exception type and the
            # messages are the same as before.
            if not isinstance(fallback, dict):
                raise AssertionError("Item={}, not a dictionary".format(fallback))
            if len(fallback) != 1:
                raise AssertionError(
                    "Only 1 key allows in dictionary. You set={} for dict={}".format(
                        len(fallback), fallback
                    )
                )

    def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
        """
        Build the logging handler that backs `routing_strategy` and register it with litellm.

        The handler is stored on the router (attribute name depends on the strategy)
        and appended to `litellm.callbacks` when that is a list. "least-busy" is
        additionally registered in `litellm.input_callback`. Any other strategy
        name (e.g. "simple-shuffle") needs no handler and is a no-op.

        Args:
            routing_strategy (str): name of the routing strategy.
            routing_strategy_args (dict): forwarded as `routing_args` to the
                usage-based and latency-based handlers.
        """

        def _register_callback(handler):
            # only register when litellm.callbacks is already a list
            if isinstance(litellm.callbacks, list):
                litellm.callbacks.append(handler)  # type: ignore

        if routing_strategy == "least-busy":
            least_busy = LeastBusyLoggingHandler(
                router_cache=self.cache, model_list=self.model_list
            )
            self.leastbusy_logger = least_busy
            ## add callback
            if not isinstance(litellm.input_callback, list):
                litellm.input_callback = [least_busy]  # type: ignore
            else:
                litellm.input_callback.append(least_busy)  # type: ignore
            _register_callback(least_busy)
            return

        if routing_strategy == "usage-based-routing":
            lowest_tpm = LowestTPMLoggingHandler(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args=routing_strategy_args,
            )
            self.lowesttpm_logger = lowest_tpm
            _register_callback(lowest_tpm)
            return

        if routing_strategy == "usage-based-routing-v2":
            lowest_tpm_v2 = LowestTPMLoggingHandler_v2(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args=routing_strategy_args,
            )
            self.lowesttpm_logger_v2 = lowest_tpm_v2
            _register_callback(lowest_tpm_v2)
            return

        if routing_strategy == "latency-based-routing":
            lowest_latency = LowestLatencyLoggingHandler(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args=routing_strategy_args,
            )
            self.lowestlatency_logger = lowest_latency
            _register_callback(lowest_latency)
            return

        if routing_strategy == "cost-based-routing":
            # the cost-based handler is always built with empty routing args
            lowest_cost = LowestCostLoggingHandler(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args={},
            )
            self.lowestcost_logger = lowest_cost
            _register_callback(lowest_cost)

    def print_deployment(self, deployment: dict):
        """
        Return a deep copy of `deployment` whose api key is masked.

        Only the first two characters of `litellm_params["api_key"]` are kept,
        followed by ten `*`. A missing api key, or one that is `None`, is left
        as-is. The `deployment` passed in is never modified.

        Args:
            deployment (dict): deployment dict containing a "litellm_params" dict.

        Returns:
            dict: the masked deep copy.

        Raises:
            Exception: any error hit while copying/masking (e.g. `KeyError` when
                "litellm_params" is missing) is logged at debug level and re-raised.
        """
        try:
            _deployment_copy = copy.deepcopy(deployment)
            litellm_params: dict = _deployment_copy["litellm_params"]
            # a key that is present but None has nothing to mask - slicing it
            # would raise a TypeError
            if "api_key" in litellm_params and litellm_params["api_key"] is not None:
                litellm_params["api_key"] = litellm_params["api_key"][:2] + "*" * 10
            return _deployment_copy
        except Exception as e:
            verbose_router_logger.debug(
                f"Error occurred while printing deployment - {str(e)}"
            )
            raise

    ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS

    def completion(
        self, model: str, messages: List[Dict[str, str]], **kwargs
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """
        Synchronous completion call routed through the router.

        Packs the request into kwargs (with `self._completion` as the function
        to execute) and hands it to `self.function_with_fallbacks`.

        Args:
            model (str): model group name; also recorded as `metadata["model_group"]`.
            messages (List[Dict[str, str]]): chat messages for the request.
            **kwargs: forwarded unchanged. `num_retries` defaults to
                `self.num_retries` when not supplied.

        Returns:
            Whatever `self.function_with_fallbacks` returns.

        Example usage:
        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
        """
        verbose_router_logger.debug(f"router.completion(model={model},..)")
        kwargs["model"] = model
        kwargs["messages"] = messages
        kwargs["original_function"] = self._completion
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # merges into a caller-supplied metadata dict when one was passed
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return self.function_with_fallbacks(**kwargs)

    def _completion(
        self, model: str, messages: List[Dict[str, str]], **kwargs
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """
        Pick a deployment for `model` and run `litellm.completion` against it.

        Steps:
        - select a deployment via `self.get_available_deployment`
        - record the deployment in `kwargs["metadata"]` / `kwargs["model_info"]`
        - fill in router-level defaults from `self.default_litellm_params`
        - reuse the client from `self._get_client`, unless a different `api_key`
          was passed for this request
        - run the routing-strategy pre-call checks, then the request

        Args:
            model (str): model group name.
            messages (List[Dict[str, str]]): chat messages for the request.
            **kwargs: request params; `specific_deployment` is popped and passed
                to the deployment selection, everything else is forwarded to
                `litellm.completion` (overriding the deployment's litellm_params).

        Raises:
            Exception: any error is logged at info level and re-raised.
        """
        model_name = None
        try:
            # pick the one that is available (lowest TPM/RPM)
            deployment = self.get_available_deployment(
                model=model,
                messages=messages,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                    "model_info": deployment.get("model_info", {}),
                }
            )
            data = deployment["litellm_params"].copy()
            kwargs["model_info"] = deployment.get("model_info", {})
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                # prioritize model-specific params > default router params.
                # Defaults that are None are skipped (same rule as _acompletion):
                # kwargs are merged last into the final call, so a None default
                # (e.g. an unset router timeout) would otherwise wipe out the
                # value configured on the deployment's litellm_params.
                if k not in kwargs and v is not None:
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)
            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                # the cached client was built with another key - don't reuse it
                model_client = None
            else:
                model_client = potential_model_client

            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
            self.routing_strategy_pre_call_checks(deployment=deployment)

            response = litellm.completion(
                **{
                    **data,
                    "messages": messages,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )
            verbose_router_logger.info(
                f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
            )

            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.completion(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            raise

    # fmt: off

    @overload
    async def acompletion(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
    ) -> CustomStreamWrapper:
        ...

    @overload
    async def acompletion(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
    ) -> ModelResponse:
        ...

    # fmt: on

    # The actual implementation of the function
    async def acompletion(
        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
    ):
        """
        Async completion call routed through the router.

        Packs the request into kwargs (with `self._acompletion` as the function
        to execute) and awaits `self.async_function_with_fallbacks`.

        Args:
            model (str): model group name; also recorded as `metadata["model_group"]`.
            messages (List[Dict[str, str]]): chat messages for the request.
            stream (bool): forwarded as `kwargs["stream"]`. Defaults to False.
            **kwargs: forwarded unchanged. `num_retries` defaults to
                `self.num_retries` when not supplied.

        Returns:
            Whatever `self.async_function_with_fallbacks` returns.
        """
        kwargs["model"] = model
        kwargs["messages"] = messages
        kwargs["stream"] = stream
        kwargs["original_function"] = self._acompletion
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # merges into a caller-supplied metadata dict when one was passed
        kwargs.setdefault("metadata", {}).update({"model_group": model})

        return await self.async_function_with_fallbacks(**kwargs)

    async def _acompletion(
        self, model: str, messages: List[Dict[str, str]], **kwargs
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """
        Pick a deployment for `model` and run `litellm.acompletion` against it.

        - Get an available deployment
        - Record it in `kwargs["metadata"]` / `kwargs["model_info"]` and fill in
          router-level defaults (defaults that are None are skipped)
        - Reuse the async client from `self._get_client`, unless a different
          `api_key` was passed for this request
        - Run the routing-strategy pre-call checks, then the request. When
          `self._get_client(..., client_type="max_parallel_requests")` returns an
          `asyncio.Semaphore`, both happen while holding that semaphore.

        Call counters (`total_calls`, `success_calls`, `fail_calls`) are updated
        per deployment model name.

        Args:
            model (str): model group name.
            messages (List[Dict[str, str]]): chat messages for the request.
            **kwargs: request params; `specific_deployment` is popped and passed
                to the deployment selection, everything else is forwarded to
                `litellm.acompletion` (overriding the deployment's litellm_params).

        Raises:
            Exception: any error is logged at info level and re-raised.
        """
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
            )

            deployment = await self.async_get_available_deployment(
                model=model,
                messages=messages,
                specific_deployment=kwargs.pop("specific_deployment", None),
                request_kwargs=kwargs,
            )

            # debug how often this deployment picked
            self._track_deployment_metrics(deployment=deployment)

            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()

            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs and v is not None
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )

            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                # the cached client was built with another key - don't reuse it
                model_client = None
            else:
                model_client = potential_model_client
            self.total_calls[model_name] += 1

            timeout = (
                data.get(
                    "timeout", None
                )  # timeout set on litellm_params for this deployment
                or self.timeout  # timeout set on router
                or kwargs.get(
                    "timeout", None
                )  # this uses default_litellm_params when nothing is set
            )

            # Build the call arguments here, but create the coroutine only once
            # the pre-call checks have passed. Creating it up-front left a
            # never-awaited coroutine behind (RuntimeWarning) whenever a
            # pre-call check raised, e.g. when the deployment is over its limit.
            request_kwargs = {
                **data,
                "messages": messages,
                "caching": self.cache_responses,
                "client": model_client,
                "timeout": timeout,
                **kwargs,
            }

            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    # - Check rpm limits before making the call
                    # - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await litellm.acompletion(**request_kwargs)
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await litellm.acompletion(**request_kwargs)

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
            )
            # debug how often this deployment picked
            self._track_deployment_metrics(deployment=deployment, response=response)

            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.acompletion(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            # model_name is still None when the failure happened before a
            # deployment was selected - nothing to attribute the failure to
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise

    async def abatch_completion(
        self, models: List[str], messages: List[Dict[str, str]], **kwargs
    ):
        """
        Async Batch Completion - send the same request to N model groups concurrently.

        Every entry of `models` gets its own `self.acompletion` call with the
        same `messages` and `kwargs`.

        Returns:
            list: one item per entry of `models`, in the same order. An item is
            either the response or, when that call failed, the exception
            instance that was raised.
        """

        async def _completion_or_error(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            # failures are handed back as values so one bad model group cannot
            # abort the whole batch
            try:
                result = await self.acompletion(
                    model=model, messages=messages, **kwargs
                )
            except Exception as err:
                return err
            return result

        pending = [
            _completion_or_error(model=model_group, messages=messages, **kwargs)
            for model_group in models
        ]
        return await asyncio.gather(*pending)

    async def abatch_completion_one_model_multiple_requests(
        self, model: str, messages: List[List[Dict[str, str]]], **kwargs
    ):
        """
        Async Batch Completion - send several independent requests to one model group concurrently.

        Args:
            model (str): model group
            messages (List[List[Dict[str, str]]]): list of requests. Each element is the
                message list of one request
            **kwargs: additional kwargs, passed to every request

        Returns:
            list: one item per request, in the same order as `messages`. An item is
            either the response or, when that call failed, the exception
            instance that was raised.

        Usage:
            response = await self.abatch_completion_one_model_multiple_requests(
                model="gpt-3.5-turbo",
                messages=[
                    [{"role": "user", "content": "hello"}, {"role": "user", "content": "tell me something funny"}],
                    [{"role": "user", "content": "hello good morning"}],
                ]
            )
        """

        async def _completion_or_error(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            # failures are handed back as values so one bad request cannot
            # abort the whole batch
            try:
                result = await self.acompletion(
                    model=model, messages=messages, **kwargs
                )
            except Exception as err:
                return err
            return result

        pending = [
            _completion_or_error(model=model, messages=single_request, **kwargs)
            for single_request in messages
        ]
        return await asyncio.gather(*pending)

    def image_generation(self, prompt: str, model: str, **kwargs):
        """
        Synchronous image generation call routed through the router.

        Packs the request into kwargs (with `self._image_generation` as the
        function to execute) and hands it to `self.function_with_fallbacks`.

        Args:
            prompt (str): prompt for the image.
            model (str): model group name; also recorded as `metadata["model_group"]`.
            **kwargs: forwarded unchanged. `num_retries` defaults to
                `self.num_retries` when not supplied.

        Returns:
            Whatever `self.function_with_fallbacks` returns.
        """
        kwargs["model"] = model
        kwargs["prompt"] = prompt
        kwargs["original_function"] = self._image_generation
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # merges into a caller-supplied metadata dict when one was passed
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return self.function_with_fallbacks(**kwargs)

    def _image_generation(self, prompt: str, model: str, **kwargs):
        """
        Run one synchronous image-generation attempt against a single deployment.

        Selects a deployment for model group `model`, merges the router's
        `default_litellm_params` into `kwargs`, runs the routing-strategy
        pre-call checks and calls `litellm.image_generation`. The counters
        `self.total_calls`, `self.success_calls` and `self.fail_calls` are
        updated under the deployment's `litellm_params["model"]` value.

        Args:
            prompt: Text prompt forwarded to `litellm.image_generation`.
            model: Model group name used to select a deployment.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The response from `litellm.image_generation`.

        Raises:
            Exception: Any error from deployment selection, the pre-call
                checks or the underlying call is logged and re-raised.
        """
        # Bound before the `try` so the `except` handler can always read it,
        # even when deployment selection itself raises.
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
            )
            # NOTE(review): the literal string "prompt" is passed as message
            # content, not the `prompt` argument - confirm this is intended.
            deployment = self.get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            # NOTE(review): this synchronous method asks for the "async"
            # client, whereas `_embedding` asks for "sync" - confirm.
            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            # A caller-supplied api_key that differs from the cached client's
            # key means the cached client must not be reused.
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1

            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
            self.routing_strategy_pre_call_checks(deployment=deployment)

            # `kwargs` is unpacked last, so caller params win over deployment params.
            response = litellm.image_generation(
                **{
                    **data,
                    "prompt": prompt,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )
            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.image_generation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.image_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            # Only count a failure once a deployment has been resolved.
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def aimage_generation(self, prompt: str, model: str, **kwargs):
        """
        Async image generation for `prompt` using the model group `model`.

        Packs the request into `kwargs` and awaits
        `self.async_function_with_fallbacks`, registering
        `self._aimage_generation` as the `original_function` to run.

        Args:
            prompt: Text prompt for the image.
            model: Model group name; also stored as `metadata["model_group"]`.
            **kwargs: Extra request params. `num_retries` falls back to
                `self.num_retries` when the caller does not supply it.

        Returns:
            Whatever `self.async_function_with_fallbacks` returns.

        Raises:
            Exception: Anything raised by `self.async_function_with_fallbacks`
                propagates unchanged.
        """
        kwargs["model"] = model
        kwargs["prompt"] = prompt
        kwargs["original_function"] = self._aimage_generation
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # Merge into caller-supplied metadata rather than replacing it.
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return await self.async_function_with_fallbacks(**kwargs)

    async def _aimage_generation(self, prompt: str, model: str, **kwargs):
        """
        Run one async image-generation attempt against a single deployment.

        Selects a deployment for model group `model`, merges the router's
        `default_litellm_params` into `kwargs`, runs the async routing-strategy
        pre-call checks (inside the deployment's semaphore when one is
        configured) and awaits `litellm.aimage_generation`. The counters
        `self.total_calls`, `self.success_calls` and `self.fail_calls` are
        updated under the deployment's `litellm_params["model"]` value.

        Args:
            prompt: Text prompt forwarded to `litellm.aimage_generation`.
            model: Model group name used to select a deployment.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The awaited response from `litellm.aimage_generation`.

        Raises:
            Exception: Any error from deployment selection, the pre-call
                checks or the underlying call is logged and re-raised.
        """
        # Bound before the `try` so the `except` handler can always read it,
        # even when deployment selection itself raises.
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _aimage_generation()- model: {model}; kwargs: {kwargs}"
            )
            # NOTE(review): the literal string "prompt" is passed as message
            # content, not the `prompt` argument - confirm this is intended.
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            # A caller-supplied api_key that differs from the cached client's
            # key means the cached client must not be reused.
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1
            # Not awaited yet: `response` is only awaited after the pre-call
            # checks below. `kwargs` is unpacked last, so caller params win.
            response = litellm.aimage_generation(
                **{
                    **data,
                    "prompt": prompt,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )

            ### CONCURRENCY-SAFE RPM CHECKS ###
            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.aimage_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            # Only count a failure once a deployment has been resolved.
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def atranscription(self, file: BinaryIO, model: str, **kwargs):
        """
        Async audio transcription of `file` using the model group `model`.

        Packs the request into `kwargs` and awaits
        `self.async_function_with_fallbacks`, registering
        `self._atranscription` as the `original_function` to run.

        Example Usage:

        ```
        from litellm import Router
        client = Router(model_list = [
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "whisper-1",
                },
            },
        ])

        audio_file = open("speech.mp3", "rb")
        transcript = await client.atranscription(
            model="whisper",
            file=audio_file
        )
        ```

        Args:
            file: Audio file object to transcribe.
            model: Model group name; also stored as `metadata["model_group"]`.
            **kwargs: Extra request params. `num_retries` falls back to
                `self.num_retries` when the caller does not supply it.

        Returns:
            Whatever `self.async_function_with_fallbacks` returns.

        Raises:
            Exception: Anything raised by `self.async_function_with_fallbacks`
                propagates unchanged.
        """
        kwargs["model"] = model
        kwargs["file"] = file
        kwargs["original_function"] = self._atranscription
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # Merge into caller-supplied metadata rather than replacing it.
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return await self.async_function_with_fallbacks(**kwargs)

    async def _atranscription(self, file: BinaryIO, model: str, **kwargs):
        """
        Run one async transcription attempt against a single deployment.

        Selects a deployment for model group `model`, merges the router's
        `default_litellm_params` into `kwargs`, runs the async routing-strategy
        pre-call checks (inside the deployment's semaphore when one is
        configured) and awaits `litellm.atranscription`. The counters
        `self.total_calls`, `self.success_calls` and `self.fail_calls` are
        updated under the deployment's `litellm_params["model"]` value.

        Args:
            file: Audio file object forwarded to `litellm.atranscription`.
            model: Model group name used to select a deployment.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The awaited response from `litellm.atranscription`.

        Raises:
            Exception: Any error from deployment selection, the pre-call
                checks or the underlying call is logged and re-raised.
        """
        # Bound before the `try` so the `except` handler can always read it,
        # even when deployment selection itself raises.
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
            )
            # A fixed placeholder message is handed to deployment selection.
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            # A caller-supplied api_key that differs from the cached client's
            # key means the cached client must not be reused.
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1
            # Not awaited yet: `response` is only awaited after the pre-call
            # checks below. `kwargs` is unpacked last, so caller params win.
            response = litellm.atranscription(
                **{
                    **data,
                    "file": file,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )

            ### CONCURRENCY-SAFE RPM CHECKS ###
            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.atranscription(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            # Only count a failure once a deployment has been resolved.
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def amoderation(self, model: str, input: str, **kwargs):
        """
        Async moderation of `input` using the model group `model`.

        Packs the request into `kwargs` and awaits
        `self.async_function_with_fallbacks`, registering `self._amoderation`
        as the `original_function` to run.

        Args:
            model: Model group name; also stored as `metadata["model_group"]`.
            input: Text to moderate.
            **kwargs: Extra request params. `num_retries` falls back to
                `self.num_retries` when the caller does not supply it.

        Returns:
            Whatever `self.async_function_with_fallbacks` returns.

        Raises:
            Exception: Anything raised by `self.async_function_with_fallbacks`
                propagates unchanged.
        """
        kwargs["model"] = model
        kwargs["input"] = input
        kwargs["original_function"] = self._amoderation
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # Merge into caller-supplied metadata rather than replacing it.
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return await self.async_function_with_fallbacks(**kwargs)

    async def _amoderation(self, model: str, input: str, **kwargs):
        """
        Run one async moderation attempt against a single deployment.

        Selects a deployment for model group `model`, folds the router's
        non-None `default_litellm_params` into `kwargs`, and awaits
        `litellm.amoderation`. Call counters are kept per deployment model in
        `self.total_calls`, `self.success_calls` and `self.fail_calls`.

        Args:
            model: Model group name used to select a deployment.
            input: Text forwarded to `litellm.amoderation`.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The awaited response from `litellm.amoderation`.

        Raises:
            Exception: Any failure is logged, counted (once a deployment is
                known) and re-raised.
        """
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
            )
            pinned_deployment = kwargs.pop("specific_deployment", None)
            deployment = await self.async_get_available_deployment(
                model=model,
                input=input,
                specific_deployment=pinned_deployment,
            )

            # Record which deployment served the request.
            request_metadata = kwargs.setdefault("metadata", {})
            request_metadata.update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})

            data = deployment["litellm_params"].copy()
            model_name = data["model"]

            # Router defaults only fill gaps; caller-provided values are kept,
            # except metadata, which is merged.
            for param, default_value in self.default_litellm_params.items():
                if param not in kwargs and default_value is not None:
                    kwargs[param] = default_value
                    continue
                if param == "metadata":
                    kwargs[param].update(default_value)

            cached_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # Drop the cached client when the caller brings a different api key.
            caller_api_key = kwargs.get("api_key", None)
            key_mismatch = (
                caller_api_key is not None
                and cached_client is not None
                and caller_api_key != cached_client.api_key
            )
            model_client = None if key_mismatch else cached_client
            self.total_calls[model_name] += 1

            # Timeout precedence: deployment litellm_params, then the router,
            # then whatever ended up in kwargs (e.g. from router defaults).
            timeout = data.get("timeout", None)
            if not timeout:
                timeout = self.timeout
            if not timeout:
                timeout = kwargs.get("timeout", None)

            # kwargs is unpacked last, so caller params win over everything else.
            request = {
                **data,
                "input": input,
                "caching": self.cache_responses,
                "client": model_client,
                "timeout": timeout,
                **kwargs,
            }
            response = await litellm.amoderation(**request)

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise

    def text_completion(
        self,
        model: str,
        prompt: str,
        is_retry: Optional[bool] = False,
        is_fallback: Optional[bool] = False,
        is_async: Optional[bool] = False,
        **kwargs,
    ):
        """
        Synchronous text completion for `prompt` using the model group `model`.

        Selects a deployment, merges the router's `default_litellm_params`
        into `kwargs` and calls `litellm.text_completion`. If anything in that
        path raises and `self.num_retries > 0`, the request is handed to
        `self.function_with_retries` with `self.completion` as the
        `original_function` and the prompt wrapped as a single user message.

        Args:
            model: Model group name; also stored as `metadata["model_group"]`.
            prompt: Prompt text.
            is_retry: Not read by this method.
            is_fallback: Not read by this method.
            is_async: Not read by this method.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The `litellm.text_completion` response, or on failure with retries
            enabled, whatever `self.function_with_retries` returns.

        Raises:
            Exception: The original error, when `self.num_retries` is not
                greater than 0.
        """
        # Built before the `try` so the retry path in the `except` handler can
        # always read it, whichever statement inside the `try` failed.
        messages = [{"role": "user", "content": prompt}]
        try:
            kwargs["model"] = model
            kwargs["prompt"] = prompt
            # NOTE(review): registers `self._acompletion` from a synchronous
            # method, and this entry is also forwarded to
            # litellm.text_completion through **kwargs - confirm intent.
            kwargs["original_function"] = self._acompletion
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            kwargs.setdefault("metadata", {}).update({"model_group": model})

            # pick the one that is available (lowest TPM/RPM)
            deployment = self.get_available_deployment(
                model=model,
                messages=messages,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )

            data = deployment["litellm_params"].copy()
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            # call via litellm.text_completion(); kwargs is unpacked last so
            # caller params win over deployment params
            return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})  # type: ignore
        except Exception as e:
            if self.num_retries > 0:
                # Retry goes through the chat-style `self.completion` entry
                # point, so the prompt is supplied as `messages`.
                kwargs["model"] = model
                kwargs["messages"] = messages
                kwargs["original_function"] = self.completion
                return self.function_with_retries(**kwargs)
            else:
                raise e

    async def atext_completion(
        self,
        model: str,
        prompt: str,
        is_retry: Optional[bool] = False,
        is_fallback: Optional[bool] = False,
        is_async: Optional[bool] = False,
        **kwargs,
    ):
        """
        Async text completion for `prompt` using the model group `model`.

        Packs the request into `kwargs` and awaits
        `self.async_function_with_fallbacks`, registering
        `self._atext_completion` as the `original_function` to run.

        Args:
            model: Model group name; also stored as `metadata["model_group"]`.
            prompt: Prompt text.
            is_retry: Not read by this method.
            is_fallback: Not read by this method.
            is_async: Not read by this method.
            **kwargs: Extra request params. `num_retries` falls back to
                `self.num_retries` when the caller does not supply it.

        Returns:
            Whatever `self.async_function_with_fallbacks` returns.

        Raises:
            Exception: Anything raised by `self.async_function_with_fallbacks`
                propagates unchanged.
        """
        kwargs["model"] = model
        kwargs["prompt"] = prompt
        kwargs["original_function"] = self._atext_completion
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # Merge into caller-supplied metadata rather than replacing it.
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return await self.async_function_with_fallbacks(**kwargs)

    async def _atext_completion(self, model: str, prompt: str, **kwargs):
        """
        Run one async text-completion attempt against a single deployment.

        Selects a deployment for model group `model`, folds the router's
        `default_litellm_params` into `kwargs`, runs the async routing-strategy
        pre-call checks (inside the deployment's semaphore when one is
        configured) and awaits `litellm.atext_completion`.

        Args:
            model: Model group name used to select a deployment.
            prompt: Prompt text forwarded to `litellm.atext_completion`.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The awaited response from `litellm.atext_completion`.

        Raises:
            Exception: Any failure is logged, counted and re-raised.
        """
        try:
            verbose_router_logger.debug(
                f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
            )
            pinned_deployment = kwargs.pop("specific_deployment", None)
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                specific_deployment=pinned_deployment,
            )

            # Record which deployment served the request.
            request_metadata = kwargs.setdefault("metadata", {})
            request_metadata.update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})

            data = deployment["litellm_params"].copy()
            model_name = data["model"]

            # Router defaults only fill gaps; caller-provided values are kept,
            # except metadata, which is merged.
            for param, default_value in self.default_litellm_params.items():
                if param not in kwargs:
                    kwargs[param] = default_value
                    continue
                if param == "metadata":
                    kwargs[param].update(default_value)

            cached_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # Drop the cached client when the caller brings a different api key.
            caller_api_key = kwargs.get("api_key", None)
            key_mismatch = (
                caller_api_key is not None
                and cached_client is not None
                and caller_api_key != cached_client.api_key
            )
            model_client = None if key_mismatch else cached_client
            self.total_calls[model_name] += 1

            # kwargs is unpacked last, so caller params win over everything else.
            request = {
                **data,
                "prompt": prompt,
                "caching": self.cache_responses,
                "client": model_client,
                "timeout": self.timeout,
                **kwargs,
            }
            # Created now, awaited only after the pre-call checks have passed.
            pending_call = litellm.atext_completion(**request)

            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            async def _check_then_call():
                # Check rpm limits before making the call; the check may raise.
                await self.async_routing_strategy_pre_call_checks(
                    deployment=deployment
                )
                return await pending_call

            if isinstance(rpm_semaphore, asyncio.Semaphore):
                # Hold the semaphore across both the check and the call.
                async with rpm_semaphore:
                    response = await _check_then_call()
            else:
                response = await _check_then_call()

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            # NOTE(review): failures are logged and counted under the model
            # group (`model`), while totals/successes use the deployment model.
            verbose_router_logger.info(
                f"litellm.atext_completion(model={model})\033[31m Exception {str(e)}\033[0m"
            )
            if model is not None:
                self.fail_calls[model] += 1
            raise

    def embedding(
        self,
        model: str,
        input: Union[str, List],
        is_async: Optional[bool] = False,
        **kwargs,
    ) -> Union[List[float], None]:
        """
        Create embeddings for `input` using the model group `model`.

        Packs the request into `kwargs` and hands it to
        `self.function_with_fallbacks`, registering `self._embedding` as the
        `original_function` to run.

        Args:
            model: Model group name; also stored as `metadata["model_group"]`.
            input: A string or list to embed.
            is_async: Not read by this method.
            **kwargs: Extra request params. `num_retries` falls back to
                `self.num_retries` when the caller does not supply it.

        Returns:
            Whatever `self.function_with_fallbacks` returns.

        Raises:
            Exception: Anything raised by `self.function_with_fallbacks`
                propagates unchanged.
        """
        kwargs["model"] = model
        kwargs["input"] = input
        kwargs["original_function"] = self._embedding
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # Merge into caller-supplied metadata rather than replacing it.
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return self.function_with_fallbacks(**kwargs)

    def _embedding(self, input: Union[str, List], model: str, **kwargs):
        """
        Run one synchronous embedding attempt against a single deployment.

        Selects a deployment for model group `model`, merges the router's
        `default_litellm_params` into `kwargs`, runs the routing-strategy
        pre-call checks and calls `litellm.embedding`. The counters
        `self.total_calls`, `self.success_calls` and `self.fail_calls` are
        updated under the deployment's `litellm_params["model"]` value.

        Args:
            input: A string or list forwarded to `litellm.embedding`.
            model: Model group name used to select a deployment.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The response from `litellm.embedding`.

        Raises:
            Exception: Any error from deployment selection, the pre-call
                checks or the underlying call is logged and re-raised.
        """
        # Bound before the `try` so the `except` handler can always read it,
        # even when deployment selection itself raises.
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside embedding()- model: {model}; kwargs: {kwargs}"
            )
            deployment = self.get_available_deployment(
                model=model,
                input=input,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="sync"
            )
            # check if provided keys == client keys #
            # A caller-supplied api_key that differs from the cached client's
            # key means the cached client must not be reused.
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1

            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
            self.routing_strategy_pre_call_checks(deployment=deployment)

            # `kwargs` is unpacked last, so caller params win over deployment params.
            response = litellm.embedding(
                **{
                    **data,
                    "input": input,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )
            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            # Only count a failure once a deployment has been resolved.
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def aembedding(
        self,
        model: str,
        input: Union[str, List],
        is_async: Optional[bool] = True,
        **kwargs,
    ) -> Union[List[float], None]:
        """
        Async embeddings for `input` using the model group `model`.

        Packs the request into `kwargs` and awaits
        `self.async_function_with_fallbacks`, registering `self._aembedding`
        as the `original_function` to run.

        Args:
            model: Model group name; also stored as `metadata["model_group"]`.
            input: A string or list to embed.
            is_async: Not read by this method.
            **kwargs: Extra request params. `num_retries` falls back to
                `self.num_retries` when the caller does not supply it.

        Returns:
            Whatever `self.async_function_with_fallbacks` returns.

        Raises:
            Exception: Anything raised by `self.async_function_with_fallbacks`
                propagates unchanged.
        """
        kwargs["model"] = model
        kwargs["input"] = input
        kwargs["original_function"] = self._aembedding
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # Merge into caller-supplied metadata rather than replacing it.
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        return await self.async_function_with_fallbacks(**kwargs)

    async def _aembedding(self, input: Union[str, List], model: str, **kwargs):
        """
        Run one async embedding attempt against a single deployment.

        Selects a deployment for model group `model`, folds the router's
        `default_litellm_params` into `kwargs`, runs the async routing-strategy
        pre-call checks (inside the deployment's semaphore when one is
        configured) and awaits `litellm.aembedding`. Call counters are kept per
        deployment model in `self.total_calls`, `self.success_calls` and
        `self.fail_calls`.

        Args:
            input: A string or list forwarded to `litellm.aembedding`.
            model: Model group name used to select a deployment.
            **kwargs: Extra call params; they override the deployment's
                `litellm_params`. `specific_deployment` is popped and passed
                to deployment selection.

        Returns:
            The awaited response from `litellm.aembedding`.

        Raises:
            Exception: Any failure is logged, counted (once a deployment is
                known) and re-raised.
        """
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
            )
            pinned_deployment = kwargs.pop("specific_deployment", None)
            deployment = await self.async_get_available_deployment(
                model=model,
                input=input,
                specific_deployment=pinned_deployment,
            )

            # Record which deployment served the request.
            request_metadata = kwargs.setdefault("metadata", {})
            request_metadata.update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})

            data = deployment["litellm_params"].copy()
            model_name = data["model"]

            # Router defaults only fill gaps; caller-provided values are kept,
            # except metadata, which is merged.
            for param, default_value in self.default_litellm_params.items():
                if param not in kwargs:
                    kwargs[param] = default_value
                    continue
                if param == "metadata":
                    kwargs[param].update(default_value)

            cached_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # Drop the cached client when the caller brings a different api key.
            caller_api_key = kwargs.get("api_key", None)
            key_mismatch = (
                caller_api_key is not None
                and cached_client is not None
                and caller_api_key != cached_client.api_key
            )
            model_client = None if key_mismatch else cached_client

            self.total_calls[model_name] += 1

            # kwargs is unpacked last, so caller params win over everything else.
            request = {
                **data,
                "input": input,
                "caching": self.cache_responses,
                "client": model_client,
                **kwargs,
            }
            # Created now, awaited only after the pre-call checks have passed.
            pending_call = litellm.aembedding(**request)

            ### CONCURRENCY-SAFE RPM CHECKS ###
            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            async def _check_then_call():
                # Check rpm limits before making the call; the check may raise.
                await self.async_routing_strategy_pre_call_checks(
                    deployment=deployment
                )
                return await pending_call

            if isinstance(rpm_semaphore, asyncio.Semaphore):
                # Hold the semaphore across both the check and the call.
                async with rpm_semaphore:
                    response = await _check_then_call()
            else:
                response = await _check_then_call()

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.aembedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise

    async def async_function_with_fallbacks(self, *args, **kwargs):
        """
        Try calling the function_with_retries
        If it fails after num_retries, fall back to another model group
        """
        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
        model_group = kwargs.get("model")
        fallbacks = kwargs.get("fallbacks", self.fallbacks)
        context_window_fallbacks = kwargs.get(
            "context_window_fallbacks", self.context_window_fallbacks
        )
        try:
            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                raise Exception(
                    f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
                )

            response = await self.async_function_with_retries(*args, **kwargs)
            verbose_router_logger.debug(f"Async Response: {response}")
            return response
        except Exception as e:
            verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
            original_exception = e
            fallback_model_group = None
            try:
                verbose_router_logger.debug(f"Trying to fallback b/w models")
                if (
                    hasattr(e, "status_code")
                    and e.status_code == 400  # type: ignore
                    and not isinstance(e, litellm.ContextWindowExceededError)
                ):  # don't retry a malformed request
                    raise e
                if (
                    isinstance(e, litellm.ContextWindowExceededError)
                    and context_window_fallbacks is not None
                ):
                    fallback_model_group = None
                    for (
                        item
                    ) in context_window_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
                        if list(item.keys())[0] == model_group:
                            fallback_model_group = item[model_group]
                            break

                    if fallback_model_group is None:
                        raise original_exception

                    for mg in fallback_model_group:
                        """
                        Iterate through the model groups and try calling that deployment
                        """
                        try:
                            kwargs["model"] = mg
                            kwargs.setdefault("metadata", {}).update(
                                {"model_group": mg}
                            )  # update model_group used, if fallbacks are done
                            response = await self.async_function_with_retries(
                                *args, **kwargs
                            )
                            verbose_router_logger.info(
                                "Successful fallback b/w models."
                            )
                            return response
                        except Exception as e:
                            pass
                elif fallbacks is not None:
                    verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
                    generic_fallback_idx: Optional[int] = None
                    ## check for specific model group-specific fallbacks
                    for idx, item in enumerate(fallbacks):
                        if list(item.keys())[0] == model_group:
                            fallback_model_group = item[model_group]
                            break
                        elif list(item.keys())[0] == "*":
                            generic_fallback_idx = idx
                    ## if none, check for generic fallback
                    if (
                        fallback_model_group is None
                        and generic_fallback_idx is not None
                    ):
                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]

                    if fallback_model_group is None:
                        verbose_router_logger.info(
                            f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
                        )
                        raise original_exception
                    for mg in fallback_model_group:
                        """
                        Iterate through the model groups and try calling that deployment
                        """
                        try:
                            ## LOGGING
                            kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
                            verbose_router_logger.info(
                                f"Falling back to model_group = {mg}"
                            )
                            kwargs["model"] = mg
                            kwargs.setdefault("metadata", {}).update(
                                {"model_group": mg}
                            )  # update model_group used, if fallbacks are done
                            response = await self.async_function_with_fallbacks(
                                *args, **kwargs
                            )
                            verbose_router_logger.info(
                                "Successful fallback b/w models."
                            )
                            return response
                        except Exception as e:
                            raise e
            except Exception as e:
                verbose_router_logger.debug(f"An exception occurred - {str(e)}")
                traceback.print_exc()
            raise original_exception

    async def async_function_with_retries(self, *args, **kwargs):
        """
        Await `original_function` and, if it fails, retry it up to `num_retries` times.

        Control kwargs (popped here, never forwarded to `original_function`):
        - original_function: coroutine function called with the remaining args / kwargs
        - num_retries: retries allowed after the first failed attempt; replaced by the value
          from `get_num_retries_from_retry_policy` when a retry policy is configured and
          that lookup returns something other than None
        - fallbacks, context_window_fallbacks: optional

        Returns:
        - the response of the first attempt that succeeds

        Raises:
        - whatever `should_retry_this_error` raises when the first failure must not be retried
        - the exception of the FIRST failed attempt, once every retry has failed
        """
        verbose_router_logger.debug(
            f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
        )
        original_function = kwargs.pop("original_function")
        # removed so it is not forwarded to original_function; the retry logic itself does not use it
        kwargs.pop("fallbacks", None)
        context_window_fallbacks = kwargs.pop(
            "context_window_fallbacks", self.context_window_fallbacks
        )

        num_retries = kwargs.pop("num_retries")

        verbose_router_logger.debug(
            f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
        )
        try:
            # first attempt - if it succeeds there is nothing to retry
            response = await original_function(*args, **kwargs)
            return response
        except Exception as e:
            original_exception = e
            _healthy_deployments = await self._async_get_healthy_deployments(
                model=kwargs.get("model") or "",
            )

            # raises an exception if this error should not be retried
            self.should_retry_this_error(
                error=e,
                healthy_deployments=_healthy_deployments,
                context_window_fallbacks=context_window_fallbacks,
            )

            # decides how long to sleep before the first retry
            _timeout = self._time_to_sleep_before_retry(
                e=original_exception,
                remaining_retries=num_retries,
                num_retries=num_retries,
                healthy_deployments=_healthy_deployments,
            )

            # sleeps for the length of the timeout
            await asyncio.sleep(_timeout)

            if (
                self.retry_policy is not None
                or self.model_group_retry_policy is not None
            ):
                # a configured retry policy may override the requested num_retries
                _retry_policy_retries = self.get_num_retries_from_retry_policy(
                    exception=original_exception, model_group=kwargs.get("model")
                )
                if _retry_policy_retries is not None:
                    num_retries = _retry_policy_retries
            ## LOGGING
            if num_retries > 0:
                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)

            for current_attempt in range(num_retries):
                verbose_router_logger.debug(
                    f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}"
                )
                try:
                    response = await original_function(*args, **kwargs)
                    # async errors are often returned as coroutines -> await the coroutine
                    # object itself (iscoroutinefunction() is never true for a call result)
                    if inspect.iscoroutine(response):
                        response = await response
                    return response

                except Exception as retry_exception:
                    ## LOGGING
                    kwargs = self.log_retry(kwargs=kwargs, e=retry_exception)
                    remaining_retries = num_retries - current_attempt
                    _healthy_deployments = await self._async_get_healthy_deployments(
                        model=kwargs.get("model"),
                    )
                    # NOTE(review): the back-off is computed from the first exception, while the
                    # sync `function_with_retries` uses the latest one - confirm which is intended
                    _timeout = self._time_to_sleep_before_retry(
                        e=original_exception,
                        remaining_retries=remaining_retries,
                        num_retries=num_retries,
                        healthy_deployments=_healthy_deployments,
                    )
                    await asyncio.sleep(_timeout)

            # Reaching this point means the loop ran to completion, i.e. exactly `num_retries`
            # retries were made. Annotating the message is best-effort: not every exception
            # exposes a writable string `message`.
            try:
                original_exception.message += f"\nNumber Retries = {num_retries}"
            except Exception:
                pass
            raise original_exception

    def should_retry_this_error(
        self,
        error: Exception,
        healthy_deployments: Optional[List] = None,
        context_window_fallbacks: Optional[List] = None,
    ):
        """
        Decide whether a failed call is worth retrying; re-raise `error` when it is not.

        `error` is raised when:
        1. it is a ContextWindowExceededError and `context_window_fallbacks` is None
        2. it is an openai RateLimitError / AuthenticationError and `healthy_deployments`
           has no entries (None or a non-list value counts as "no entries")

        Returns:
        - True in every other case

        NOTE(review): rule 1 fires when NO context-window fallbacks are configured; when
        fallbacks exist the error is retried first - confirm this is the intended direction.
        """
        # anything that is not a list is treated as "no healthy deployments"
        healthy_count = (
            len(healthy_deployments) if isinstance(healthy_deployments, list) else 0
        )

        if context_window_fallbacks is None and isinstance(
            error, litellm.ContextWindowExceededError
        ):
            raise error

        # these errors are only worth retrying when another deployment can take the call
        needs_other_deployment = isinstance(
            error, (openai.RateLimitError, openai.AuthenticationError)
        )
        if needs_other_deployment and healthy_count <= 0:
            raise error

        return True

    def function_with_fallbacks(self, *args, **kwargs):
        """
        Call `function_with_retries`; if it raises, try the configured fallback model groups.

        kwargs read here:
        - mock_testing_fallbacks: popped; when it equals True a mock exception is raised
          instead of making the call, which forces the fallback path
        - model: the model group of the original request
        - fallbacks / context_window_fallbacks: read with `.get`, so they stay in kwargs and
          are forwarded to `function_with_retries`; default to the router-level settings

        Fallback selection once the call has failed:
        - an error with status_code 400 that is not a ContextWindowExceededError is re-raised
        - ContextWindowExceededError with context_window_fallbacks set: the entry whose first
          key equals the model group is used; a failing fallback is skipped
        - otherwise, with fallbacks set: the entry for the model group is used, else the "*"
          entry; a failing fallback is re-raised
        Every fallback attempt is a recursive call to this method with kwargs["model"] replaced.

        Returns:
        - the response of the original call or of the first successful fallback

        Raises:
        - the original exception when no fallback applies or none succeeds
        - the fallback's own exception in the generic `fallbacks` branch
        """
        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
        model_group = kwargs.get("model")
        fallbacks = kwargs.get("fallbacks", self.fallbacks)
        context_window_fallbacks = kwargs.get(
            "context_window_fallbacks", self.context_window_fallbacks
        )
        try:
            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                raise Exception(
                    f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
                )

            response = self.function_with_retries(*args, **kwargs)
            return response
        except Exception as e:
            # keep a handle on the first failure - the nested handlers below rebind `e`
            original_exception = e
            verbose_router_logger.debug(f"An exception occurs {original_exception}")
            try:
                if (
                    hasattr(e, "status_code")
                    and e.status_code == 400  # type: ignore
                    and not isinstance(e, litellm.ContextWindowExceededError)
                ):  # don't retry a malformed request
                    raise e

                verbose_router_logger.debug(
                    f"Trying to fallback b/w models. Initial model group: {model_group}"
                )
                if (
                    isinstance(e, litellm.ContextWindowExceededError)
                    and context_window_fallbacks is not None
                ):
                    fallback_model_group = None

                    # each item is a single-key dict: {model_group: [fallback groups]}
                    for (
                        item
                    ) in context_window_fallbacks:  # [{"gpt-3.5-turbo": ["gpt-4"]}]
                        if list(item.keys())[0] == model_group:
                            fallback_model_group = item[model_group]
                            break

                    if fallback_model_group is None:
                        raise original_exception

                    for mg in fallback_model_group:
                        """
                        Iterate through the model groups and try calling that deployment
                        """
                        try:
                            ## LOGGING
                            kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
                            kwargs["model"] = mg
                            kwargs.setdefault("metadata", {}).update(
                                {"model_group": mg}
                            )  # update model_group used, if fallbacks are done
                            response = self.function_with_fallbacks(*args, **kwargs)
                            return response
                        except Exception as e:
                            # a failing fallback is ignored; the next group is tried and, if
                            # all fail, control reaches the final `raise original_exception`
                            pass
                elif fallbacks is not None:
                    verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
                    fallback_model_group = None
                    generic_fallback_idx: Optional[int] = None
                    ## check for specific model group-specific fallbacks
                    for idx, item in enumerate(fallbacks):
                        if list(item.keys())[0] == model_group:
                            fallback_model_group = item[model_group]
                            break
                        elif list(item.keys())[0] == "*":
                            # remember the wildcard entry; only used if no specific entry matches
                            generic_fallback_idx = idx
                    ## if none, check for generic fallback
                    if (
                        fallback_model_group is None
                        and generic_fallback_idx is not None
                    ):
                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]

                    if fallback_model_group is None:
                        raise original_exception

                    for mg in fallback_model_group:
                        """
                        Iterate through the model groups and try calling that deployment
                        """
                        try:
                            ## LOGGING
                            kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
                            kwargs["model"] = mg
                            kwargs.setdefault("metadata", {}).update(
                                {"model_group": mg}
                            )  # update model_group used, if fallbacks are done
                            response = self.function_with_fallbacks(*args, **kwargs)
                            return response
                        except Exception as e:
                            # NOTE(review): re-raising here ends the loop on the first failing
                            # fallback, so later entries of fallback_model_group are never
                            # attempted (unlike the context-window branch) - confirm intended
                            raise e
            except Exception as e:
                raise e
            # reached when no fallback branch applied or every context-window fallback failed
            raise original_exception

    def _time_to_sleep_before_retry(
        self,
        e: Exception,
        remaining_retries: int,
        num_retries: int,
        healthy_deployments: Optional[List] = None,
    ) -> Union[int, float]:
        """
        Calculate back-off, then retry

        It should instantly retry only when:
            1. there are healthy deployments in the same model group
            2. there are fallbacks for the completion call
        """
        if (
            healthy_deployments is not None
            and isinstance(healthy_deployments, list)
            and len(healthy_deployments) > 0
        ):
            return 0

        if hasattr(e, "response") and hasattr(e.response, "headers"):
            timeout = litellm._calculate_retry_after(
                remaining_retries=remaining_retries,
                max_retries=num_retries,
                response_headers=e.response.headers,
                min_timeout=self.retry_after,
            )
        else:
            timeout = litellm._calculate_retry_after(
                remaining_retries=remaining_retries,
                max_retries=num_retries,
                min_timeout=self.retry_after,
            )
        return timeout

    def function_with_retries(self, *args, **kwargs):
        """
        Synchronous retry wrapper: call `original_function`, retrying up to `num_retries` times.

        Control kwargs (popped here, never forwarded to `original_function`):
        - original_function: the callable to invoke with the remaining args / kwargs
        - num_retries: retries allowed after the first failed attempt
        - fallbacks, context_window_fallbacks: optional, default to the router-level settings

        Returns:
        - the response of the first attempt that succeeds

        Raises:
        - whatever `should_retry_this_error` raises when the first failure must not be retried
        - the exception of the FIRST failed attempt, once every retry has failed
        """
        verbose_router_logger.debug(
            f"Inside function with retries: args - {args}; kwargs - {kwargs}"
        )
        original_function = kwargs.pop("original_function")
        num_retries = kwargs.pop("num_retries")
        kwargs.pop("fallbacks", self.fallbacks)  # stripped from kwargs; not used below
        context_window_fallbacks = kwargs.pop(
            "context_window_fallbacks", self.context_window_fallbacks
        )

        try:
            # first attempt - a success returns straight away
            return original_function(*args, **kwargs)
        except Exception as first_error:
            original_exception = first_error
            healthy = self._get_healthy_deployments(
                model=kwargs.get("model"),
            )

            # raises when this kind of error should not be retried at all
            self.should_retry_this_error(
                error=first_error,
                healthy_deployments=healthy,
                context_window_fallbacks=context_window_fallbacks,
            )

            # wait time before the first retry
            backoff = self._time_to_sleep_before_retry(
                e=original_exception,
                remaining_retries=num_retries,
                num_retries=num_retries,
                healthy_deployments=healthy,
            )

            ## LOGGING
            if num_retries > 0:
                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)

            time.sleep(backoff)
            for current_attempt in range(num_retries):
                verbose_router_logger.debug(
                    f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
                )
                try:
                    return original_function(*args, **kwargs)
                except Exception as retry_error:
                    ## LOGGING
                    kwargs = self.log_retry(kwargs=kwargs, e=retry_error)
                    healthy = self._get_healthy_deployments(
                        model=kwargs.get("model"),
                    )
                    # the wait before the next attempt is based on the latest failure
                    backoff = self._time_to_sleep_before_retry(
                        e=retry_error,
                        remaining_retries=num_retries - current_attempt,
                        num_retries=num_retries,
                        healthy_deployments=healthy,
                    )
                    time.sleep(backoff)
            raise original_exception

    ### HELPER FUNCTIONS

    def deployment_callback_on_failure(
        self,
        kwargs,  # kwargs to completion
        completion_response,  # response from completion
        start_time,
        end_time,  # start/end time
    ):
        try:
            exception = kwargs.get("exception", None)
            exception_type = type(exception)
            exception_status = getattr(exception, "status_code", "")
            exception_cause = getattr(exception, "__cause__", "")
            exception_message = getattr(exception, "message", "")
            exception_str = (
                str(exception_type)
                + "Status: "
                + str(exception_status)
                + "Message: "
                + str(exception_cause)
                + str(exception_message)
                + "Full exception"
                + str(exception)
            )
            model_name = kwargs.get("model", None)  # i.e. gpt35turbo
            custom_llm_provider = kwargs.get("litellm_params", {}).get(
                "custom_llm_provider", None
            )  # i.e. azure
            metadata = kwargs.get("litellm_params", {}).get("metadata", None)
            _model_info = kwargs.get("litellm_params", {}).get("model_info", {})

            exception_response = getattr(exception, "response", {})
            exception_headers = getattr(exception_response, "headers", None)
            _time_to_cooldown = self.cooldown_time

            if exception_headers is not None:

                _time_to_cooldown = (
                    litellm.utils._get_retry_after_from_exception_header(
                        response_headers=exception_headers
                    )
                )

                if _time_to_cooldown < 0:
                    # if the response headers did not read it -> set to default cooldown time
                    _time_to_cooldown = self.cooldown_time

            if isinstance(_model_info, dict):
                deployment_id = _model_info.get("id", None)
                self._set_cooldown_deployments(
                    exception_status=exception_status,
                    deployment=deployment_id,
                    time_to_cooldown=_time_to_cooldown,
                )  # setting deployment_id in cooldown deployments
            if custom_llm_provider:
                model_name = f"{custom_llm_provider}/{model_name}"

        except Exception as e:
            raise e

    def log_retry(self, kwargs: dict, e: Exception) -> dict:
        """
        When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing
        """
        try:
            # Log failed model as the previous model
            previous_model = {
                "exception_type": type(e).__name__,
                "exception_string": str(e),
            }
            for (
                k,
                v,
            ) in (
                kwargs.items()
            ):  # log everything in kwargs except the old previous_models value - prevent nesting
                if k not in ["metadata", "messages", "original_function"]:
                    previous_model[k] = v
                elif k == "metadata" and isinstance(v, dict):
                    previous_model["metadata"] = {}  # type: ignore
                    for metadata_k, metadata_v in kwargs["metadata"].items():
                        if metadata_k != "previous_models":
                            previous_model[k][metadata_k] = metadata_v  # type: ignore

            # check current size of self.previous_models, if it's larger than 3, remove the first element
            if len(self.previous_models) > 3:
                self.previous_models.pop(0)

            self.previous_models.append(previous_model)
            kwargs["metadata"]["previous_models"] = self.previous_models
            return kwargs
        except Exception as e:
            raise e

    def _update_usage(self, deployment_id: str):
        """
        Update deployment rpm for that minute
        """
        rpm_key = deployment_id

        request_count = self.cache.get_cache(key=rpm_key, local_only=True)
        if request_count is None:
            request_count = 1
            self.cache.set_cache(
                key=rpm_key, value=request_count, local_only=True, ttl=60
            )  # only store for 60s
        else:
            request_count += 1
            self.cache.set_cache(
                key=rpm_key, value=request_count, local_only=True
            )  # don't change existing ttl

    def _is_cooldown_required(self, exception_status: Union[str, int]):
        """
        A function to determine if a cooldown is required based on the exception status.

        Parameters:
            exception_status (Union[str, int]): The status of the exception.

        Returns:
            bool: True if a cooldown is required, False otherwise.
        """
        try:

            if isinstance(exception_status, str):
                exception_status = int(exception_status)

            if exception_status >= 400 and exception_status < 500:
                if exception_status == 429:
                    # Cool down 429 Rate Limit Errors
                    return True

                elif exception_status == 401:
                    # Cool down 401 Auth Errors
                    return True

                elif exception_status == 408:
                    return True

                else:
                    # Do NOT cool down all other 4XX Errors
                    return False

            else:
                # should cool down for all other errors
                return True

        except:
            # Catch all - if any exceptions default to cooling down
            return True

    def _set_cooldown_deployments(
        self,
        exception_status: Union[str, int],
        deployment: Optional[str] = None,
        time_to_cooldown: Optional[float] = None,
    ):
        """
        Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute

        or

        the exception is not one that should be immediately retried (e.g. 401)
        """
        if deployment is None:
            return

        if self._is_cooldown_required(exception_status=exception_status) == False:
            return

        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        # get current fails for deployment
        # update the number of failed calls
        # if it's > allowed fails
        # cooldown deployment
        current_fails = self.failed_calls.get_cache(key=deployment) or 0
        updated_fails = current_fails + 1
        verbose_router_logger.debug(
            f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
        )
        cooldown_time = self.cooldown_time or 1
        if time_to_cooldown is not None:
            cooldown_time = time_to_cooldown

        if isinstance(exception_status, str):
            try:
                exception_status = int(exception_status)
            except Exception as e:
                verbose_router_logger.debug(
                    "Unable to cast exception status to int {}. Defaulting to status=500.".format(
                        exception_status
                    )
                )
                exception_status = 500
        _should_retry = litellm._should_retry(status_code=exception_status)

        if updated_fails > self.allowed_fails or _should_retry == False:
            # get the current cooldown list for that minute
            cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
            cached_value = self.cache.get_cache(key=cooldown_key)

            verbose_router_logger.debug(f"adding {deployment} to cooldown models")
            # update value
            try:
                if deployment in cached_value:
                    pass
                else:
                    cached_value = cached_value + [deployment]
                    # save updated value
                    self.cache.set_cache(
                        value=cached_value, key=cooldown_key, ttl=cooldown_time
                    )
            except:
                cached_value = [deployment]
                # save updated value
                self.cache.set_cache(
                    value=cached_value, key=cooldown_key, ttl=cooldown_time
                )

            self.send_deployment_cooldown_alert(
                deployment_id=deployment,
                exception_status=exception_status,
                cooldown_time=cooldown_time,
            )
        else:
            self.failed_calls.set_cache(
                key=deployment, value=updated_fails, ttl=cooldown_time
            )

    async def _async_get_cooldown_deployments(self):
        """
        Async implementation of '_get_cooldown_deployments'

        Returns the cooldown list cached for the current minute, or [] when nothing is cached.
        """
        # cooldown lists are grouped per minute, keyed by the "HH-MM" stamp
        current_minute = get_utc_datetime().strftime("%H-%M")

        cooldown_models = (
            await self.cache.async_get_cache(key=f"{current_minute}:cooldown_models")
            or []
        )

        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
        return cooldown_models

    def _get_cooldown_deployments(self):
        """
        Get the list of models being cooled down for this minute

        Returns the cooldown list cached for the current minute, or [] when nothing is cached.
        """
        # cooldown lists are grouped per minute, keyed by the "HH-MM" stamp
        current_minute = get_utc_datetime().strftime("%H-%M")

        cooldown_models = (
            self.cache.get_cache(key=f"{current_minute}:cooldown_models") or []
        )

        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
        return cooldown_models

    def _get_healthy_deployments(self, model: str):
        _all_deployments: list = []
        try:
            _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
                model=model,
            )
            if type(_all_deployments) == dict:
                return []
        except:
            pass

        unhealthy_deployments = self._get_cooldown_deployments()
        healthy_deployments: list = []
        for deployment in _all_deployments:
            if deployment["model_info"]["id"] in unhealthy_deployments:
                continue
            else:
                healthy_deployments.append(deployment)

        return healthy_deployments

    async def _async_get_healthy_deployments(self, model: str):
        _all_deployments: list = []
        try:
            _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
                model=model,
            )
            if type(_all_deployments) == dict:
                return []
        except:
            pass

        unhealthy_deployments = await self._async_get_cooldown_deployments()
        healthy_deployments: list = []
        for deployment in _all_deployments:
            if deployment["model_info"]["id"] in unhealthy_deployments:
                continue
            else:
                healthy_deployments.append(deployment)
        return healthy_deployments

    def routing_strategy_pre_call_checks(self, deployment: dict):
        """
        Mimics 'async_routing_strategy_pre_call_checks'

        Ensures consistent update rpm implementation for 'usage-based-routing-v2'

        Runs `pre_call_check(deployment)` on every CustomLogger found in litellm.callbacks.

        Returns:
        - None

        Raises:
        - Rate Limit Exception - If the deployment is over its tpm/rpm limits
        """
        # callbacks that are not CustomLogger instances are skipped
        custom_loggers = (
            cb for cb in litellm.callbacks if isinstance(cb, CustomLogger)
        )
        for custom_logger in custom_loggers:
            custom_logger.pre_call_check(deployment)

    async def async_routing_strategy_pre_call_checks(self, deployment: dict):
        """
        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.

        -> makes the calls concurrency-safe, when rpm limits are set for a deployment

        Awaits `async_pre_call_check(deployment)` on every CustomLogger found in
        litellm.callbacks, one after the other.

        Returns:
        - None

        Raises:
        - Rate Limit Exception - If the deployment is over its tpm/rpm limits
        """
        # callbacks that are not CustomLogger instances are skipped
        custom_loggers = (
            cb for cb in litellm.callbacks if isinstance(cb, CustomLogger)
        )
        for custom_logger in custom_loggers:
            await custom_logger.async_pre_call_check(deployment)

    def set_client(self, model: dict):
        """
        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
        """
        client_ttl = self.client_ttl
        litellm_params = model.get("litellm_params", {})
        model_name = litellm_params.get("model")
        model_id = model["model_info"]["id"]
        # ### IF RPM SET - initialize a semaphore ###
        rpm = litellm_params.get("rpm", None)
        tpm = litellm_params.get("tpm", None)
        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
        calculated_max_parallel_requests = calculate_max_parallel_requests(
            rpm=rpm,
            max_parallel_requests=max_parallel_requests,
            tpm=tpm,
            default_max_parallel_requests=self.default_max_parallel_requests,
        )
        if calculated_max_parallel_requests:
            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
            cache_key = f"{model_id}_max_parallel_requests_client"
            self.cache.set_cache(
                key=cache_key,
                value=semaphore,
                local_only=True,
            )

        ####  for OpenAI / Azure we need to initalize the Client for High Traffic ########
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
        default_api_base = None
        default_api_key = None
        if custom_llm_provider in litellm.openai_compatible_providers:
            _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
                model=model_name
            )
            default_api_base = api_base
            default_api_key = api_key

        if (
            model_name in litellm.open_ai_chat_completion_models
            or custom_llm_provider in litellm.openai_compatible_providers
            or custom_llm_provider == "azure"
            or custom_llm_provider == "azure_text"
            or custom_llm_provider == "custom_openai"
            or custom_llm_provider == "openai"
            or custom_llm_provider == "text-completion-openai"
            or "ft:gpt-3.5-turbo" in model_name
            or model_name in litellm.open_ai_embedding_models
        ):
            is_azure_ai_studio_model: bool = False
            if custom_llm_provider == "azure":
                if litellm.utils._is_non_openai_azure_model(model_name):
                    is_azure_ai_studio_model = True
                    custom_llm_provider = "openai"
                    # remove azure prefix from model_name
                    model_name = model_name.replace("azure/", "")
            # glorified / complicated reading of configs
            # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
            # we do this here because we init clients for Azure, OpenAI and we need to set the right key
            api_key = litellm_params.get("api_key") or default_api_key
            if (
                api_key
                and isinstance(api_key, str)
                and api_key.startswith("os.environ/")
            ):
                api_key_env_name = api_key.replace("os.environ/", "")
                api_key = litellm.get_secret(api_key_env_name)
                litellm_params["api_key"] = api_key

            api_base = litellm_params.get("api_base")
            base_url = litellm_params.get("base_url")
            api_base = (
                api_base or base_url or default_api_base
            )  # allow users to pass in `api_base` or `base_url` for azure
            if api_base and api_base.startswith("os.environ/"):
                api_base_env_name = api_base.replace("os.environ/", "")
                api_base = litellm.get_secret(api_base_env_name)
                litellm_params["api_base"] = api_base

            ## AZURE AI STUDIO MISTRAL CHECK ##
            """
            Make sure api base ends in /v1/

            if not, add it - https://github.com/BerriAI/litellm/issues/2279
            """
            if (
                is_azure_ai_studio_model == True
                and api_base is not None
                and isinstance(api_base, str)
                and not api_base.endswith("/v1/")
            ):
                # check if it ends with a trailing slash
                if api_base.endswith("/"):
                    api_base += "v1/"
                elif api_base.endswith("/v1"):
                    api_base += "/"
                else:
                    api_base += "/v1/"

            api_version = litellm_params.get("api_version")
            if api_version and api_version.startswith("os.environ/"):
                api_version_env_name = api_version.replace("os.environ/", "")
                api_version = litellm.get_secret(api_version_env_name)
                litellm_params["api_version"] = api_version

            timeout = litellm_params.pop("timeout", None)
            if isinstance(timeout, str) and timeout.startswith("os.environ/"):
                timeout_env_name = timeout.replace("os.environ/", "")
                timeout = litellm.get_secret(timeout_env_name)
                litellm_params["timeout"] = timeout

            stream_timeout = litellm_params.pop(
                "stream_timeout", timeout
            )  # if no stream_timeout is set, default to timeout
            if isinstance(stream_timeout, str) and stream_timeout.startswith(
                "os.environ/"
            ):
                stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
                stream_timeout = litellm.get_secret(stream_timeout_env_name)
                litellm_params["stream_timeout"] = stream_timeout

            max_retries = litellm_params.pop(
                "max_retries", 0
            )  # router handles retry logic
            if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
                max_retries_env_name = max_retries.replace("os.environ/", "")
                max_retries = litellm.get_secret(max_retries_env_name)
                litellm_params["max_retries"] = max_retries

            # proxy support
            import os
            import httpx

            # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
            http_proxy = os.getenv("HTTP_PROXY", None)
            https_proxy = os.getenv("HTTPS_PROXY", None)
            no_proxy = os.getenv("NO_PROXY", None)

            # Create the proxies dictionary only if the environment variables are set.
            sync_proxy_mounts = None
            async_proxy_mounts = None
            if http_proxy is not None and https_proxy is not None:
                sync_proxy_mounts = {
                    "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
                    "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
                }
                async_proxy_mounts = {
                    "http://": httpx.AsyncHTTPTransport(
                        proxy=httpx.Proxy(url=http_proxy)
                    ),
                    "https://": httpx.AsyncHTTPTransport(
                        proxy=httpx.Proxy(url=https_proxy)
                    ),
                }

                # assume no_proxy is a list of comma separated urls
                if no_proxy is not None and isinstance(no_proxy, str):
                    no_proxy_urls = no_proxy.split(",")

                    for url in no_proxy_urls:  # set no-proxy support for specific urls
                        sync_proxy_mounts[url] = None  # type: ignore
                        async_proxy_mounts[url] = None  # type: ignore

            organization = litellm_params.get("organization", None)
            if isinstance(organization, str) and organization.startswith("os.environ/"):
                organization_env_name = organization.replace("os.environ/", "")
                organization = litellm.get_secret(organization_env_name)
                litellm_params["organization"] = organization

            if "azure" in model_name:
                if api_base is None or not isinstance(api_base, str):
                    raise ValueError(
                        f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
                    )
                azure_ad_token = litellm_params.get("azure_ad_token")
                if azure_ad_token is not None:
                    if azure_ad_token.startswith("oidc/"):
                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                if api_version is None:
                    api_version = "2023-07-01-preview"

                if "gateway.ai.cloudflare.com" in api_base:
                    if not api_base.endswith("/"):
                        api_base += "/"
                    azure_model = model_name.replace("azure/", "")
                    api_base += f"{azure_model}"
                    cache_key = f"{model_id}_async_client"
                    _client = openai.AsyncAzureOpenAI(
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=timeout,
                        max_retries=max_retries,
                        http_client=httpx.AsyncClient(
                            transport=AsyncCustomHTTPTransport(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                            mounts=async_proxy_mounts,
                        ),  # type: ignore
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    cache_key = f"{model_id}_client"
                    _client = openai.AzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=timeout,
                        max_retries=max_retries,
                        http_client=httpx.Client(
                            transport=CustomHTTPTransport(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                            mounts=sync_proxy_mounts,
                        ),  # type: ignore
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr
                    # streaming clients can have diff timeouts
                    cache_key = f"{model_id}_stream_async_client"
                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,
                        max_retries=max_retries,
                        http_client=httpx.AsyncClient(
                            transport=AsyncCustomHTTPTransport(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                            mounts=async_proxy_mounts,
                        ),  # type: ignore
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    cache_key = f"{model_id}_stream_client"
                    _client = openai.AzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,
                        max_retries=max_retries,
                        http_client=httpx.Client(
                            transport=CustomHTTPTransport(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                            mounts=sync_proxy_mounts,
                        ),  # type: ignore
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr
                else:
                    _api_key = api_key
                    if _api_key is not None and isinstance(_api_key, str):
                        # only show first 8 chars of api_key
                        _api_key = _api_key[:8] + "*" * 15
                    verbose_router_logger.debug(
                        f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
                    )
                    azure_client_params = {
                        "api_key": api_key,
                        "azure_endpoint": api_base,
                        "api_version": api_version,
                        "azure_ad_token": azure_ad_token,
                    }
                    from litellm.llms.azure import select_azure_base_url_or_endpoint

                    # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
                    # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
                    azure_client_params = select_azure_base_url_or_endpoint(
                        azure_client_params
                    )

                    cache_key = f"{model_id}_async_client"
                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        **azure_client_params,
                        timeout=timeout,
                        max_retries=max_retries,
                        http_client=httpx.AsyncClient(
                            transport=AsyncCustomHTTPTransport(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                            mounts=async_proxy_mounts,
                        ),  # type: ignore
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    cache_key = f"{model_id}_client"
                    _client = openai.AzureOpenAI(  # type: ignore
                        **azure_client_params,
                        timeout=timeout,
                        max_retries=max_retries,
                        http_client=httpx.Client(
                            transport=CustomHTTPTransport(
                                verify=litellm.ssl_verify,
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                            ),
                            mounts=sync_proxy_mounts,
                        ),  # type: ignore
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    # streaming clients should have diff timeouts
                    cache_key = f"{model_id}_stream_async_client"
                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        **azure_client_params,
                        timeout=stream_timeout,
                        max_retries=max_retries,
                        http_client=httpx.AsyncClient(
                            transport=AsyncCustomHTTPTransport(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                            mounts=async_proxy_mounts,
                        ),
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    cache_key = f"{model_id}_stream_client"
                    _client = openai.AzureOpenAI(  # type: ignore
                        **azure_client_params,
                        timeout=stream_timeout,
                        max_retries=max_retries,
                        http_client=httpx.Client(
                            transport=CustomHTTPTransport(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                            mounts=sync_proxy_mounts,
                        ),
                    )
                    self.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

            else:
                _api_key = api_key  # type: ignore
                if _api_key is not None and isinstance(_api_key, str):
                    # only show first 8 chars of api_key
                    _api_key = _api_key[:8] + "*" * 15
                verbose_router_logger.debug(
                    f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
                )
                cache_key = f"{model_id}_async_client"
                _client = openai.AsyncOpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=async_proxy_mounts,
                    ),  # type: ignore
                )
                self.cache.set_cache(
                    key=cache_key,
                    value=_client,
                    ttl=client_ttl,
                    local_only=True,
                )  # cache for 1 hr

                cache_key = f"{model_id}_client"
                _client = openai.OpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=sync_proxy_mounts,
                    ),  # type: ignore
                )
                self.cache.set_cache(
                    key=cache_key,
                    value=_client,
                    ttl=client_ttl,
                    local_only=True,
                )  # cache for 1 hr

                # streaming clients should have diff timeouts
                cache_key = f"{model_id}_stream_async_client"
                _client = openai.AsyncOpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=async_proxy_mounts,
                    ),  # type: ignore
                )
                self.cache.set_cache(
                    key=cache_key,
                    value=_client,
                    ttl=client_ttl,
                    local_only=True,
                )  # cache for 1 hr

                # streaming clients should have diff timeouts
                cache_key = f"{model_id}_stream_client"
                _client = openai.OpenAI(  # type: ignore
                    api_key=api_key,
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                        mounts=sync_proxy_mounts,
                    ),  # type: ignore
                )
                self.cache.set_cache(
                    key=cache_key,
                    value=_client,
                    ttl=client_ttl,
                    local_only=True,
                )  # cache for 1 hr

    def _generate_model_id(self, model_group: str, litellm_params: dict):
        """
        Build a deterministic id for a deployment.

        The model group name, followed by every key and value of
        `litellm_params` (in the dict's iteration order), is rendered to text,
        concatenated and hashed with SHA-256. The hex digest is the id.

        Rendering rules (applied to keys and values alike):
        - str  -> used as-is
        - dict -> json.dumps(...)
        - anything else -> str(...)
        """

        def _as_text(item) -> str:
            # strings pass through, dicts are JSON-encoded, everything else uses str()
            if isinstance(item, str):
                return item
            if isinstance(item, dict):
                return json.dumps(item)
            return str(item)

        pieces = [model_group]
        for param_name, param_value in litellm_params.items():
            pieces.append(_as_text(param_name))
            pieces.append(_as_text(param_value))

        digest_input = "".join(pieces)
        return hashlib.sha256(digest_input.encode()).hexdigest()

    def set_model_list(self, model_list: list):
        """
        Rebuild `self.model_list` and `self.model_names` from `model_list`.

        The input is deep-copied first, so the caller's dicts are not mutated.
        For every entry:
        - string litellm params starting with "os.environ/" are resolved
          through `litellm.get_secret`
        - an id is generated for `model_info` when it does not carry one
        - the entry is wrapped in a `Deployment`, passed through
          `_add_deployment`, and stored as its JSON form (None values excluded)
        """
        pending_entries = copy.deepcopy(model_list)
        self.model_list = []
        # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
        for entry in pending_entries:
            group_name = entry.pop("model_name")
            params = entry.pop("litellm_params")

            ## check if litellm params in os.environ
            if isinstance(params, dict):
                for param_key, param_value in params.items():
                    if not isinstance(param_value, str):
                        continue
                    if param_value.startswith("os.environ/"):
                        params[param_key] = litellm.get_secret(param_value)

            info: dict = entry.pop("model_info", {})

            # only generate an id when the caller did not provide one
            if "id" not in info:
                info["id"] = self._generate_model_id(group_name, params)

            # whatever is left in `entry` is forwarded to Deployment as extra fields
            registered = self._add_deployment(
                deployment=Deployment(
                    **entry,
                    model_name=group_name,
                    litellm_params=params,
                    model_info=info,
                )
            )

            self.model_list.append(registered.to_json(exclude_none=True))

        verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
        # names come from the caller's original list, one per entry
        self.model_names = [original["model_name"] for original in model_list]

    def _add_deployment(self, deployment: Deployment) -> Deployment:
        """
        Register a deployment on the router and initialise its clients.

        Steps, in order:
        1. Append the deployment's litellm model to `self.deployment_names`.
        2. Copy a deployment-level `rpm` / `tpm` into `litellm_params` when the
           litellm param is not set.
        3. Resolve the provider through `litellm.get_llm_provider`.
        4. If `model_name` is "*", store the deployment as `self.default_deployment`.
        5. Replace "os.environ/..." values of `endpoint` / `key` inside the
           `dataSources` parameters with the environment variable's value.
        6. Raise when the provider is not in `litellm.provider_list`.
        7. Call `self.set_client` with the deployment's JSON form.
        8. When `litellm.enable_preview_features` is on, try to infer the region
           of azure models; failures are logged and ignored.

        Returns:
        - The same deployment object (its litellm params may have been updated)

        Raises:
        - Exception - if the resolved provider is not supported
        """
        import os

        #### DEPLOYMENT NAMES INIT ########
        self.deployment_names.append(deployment.litellm_params.model)

        ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
        # get_available_deployment reads the litellm param, so mirror the
        # deployment-level value there when the litellm param is missing
        for limit_name in ("rpm", "tpm"):
            if getattr(deployment.litellm_params, limit_name) is not None:
                continue
            if getattr(deployment, limit_name, None) is not None:
                setattr(
                    deployment.litellm_params,
                    limit_name,
                    getattr(deployment, limit_name),
                )

        #### VALIDATE MODEL ########
        # only the provider is needed from the resolved tuple
        _, custom_llm_provider, _, _ = litellm.get_llm_provider(
            model=deployment.litellm_params.model,
            custom_llm_provider=deployment.litellm_params.get(
                "custom_llm_provider", None
            ),
        )

        # model_name == "*" is a catch all model for the user's specific api key
        if deployment.model_name == "*":
            self.default_deployment = deployment.to_json(exclude_none=True)

        # Azure GPT-Vision Enhancements, users can pass os.environ/
        for data_source in deployment.litellm_params.get("dataSources", []) or []:
            source_params = data_source.get("parameters", {})
            for field_name in ("endpoint", "key"):
                if field_name not in source_params:
                    continue
                raw_value = source_params[field_name]
                # if endpoint or key is set as an env var reference, read it from the env
                if raw_value.startswith("os.environ/"):
                    env_name = raw_value.replace("os.environ/", "")
                    source_params[field_name] = os.environ.get(env_name, "")

        # done reading model["litellm_params"]
        if custom_llm_provider not in litellm.provider_list:
            raise Exception(f"Unsupported provider - {custom_llm_provider}")

        # init OpenAI, Azure clients
        self.set_client(model=deployment.to_json(exclude_none=True))

        # set region (if azure model) ## PREVIEW FEATURE ##
        # Hiding behind a feature flag:
        # when there is a large amount of LLM deployments this makes startup times blow up
        if litellm.enable_preview_features == True:
            print("Auto inferring region")  # noqa
            try:
                is_azure = "azure" in deployment.litellm_params.model
                if is_azure and deployment.litellm_params.region_name is None:
                    deployment.litellm_params.region_name = (
                        litellm.utils.get_model_region(
                            litellm_params=deployment.litellm_params, mode=None
                        )
                    )
            except Exception as e:
                # [NON-BLOCKING] region inference is best-effort
                verbose_router_logger.debug(
                    "Unable to get the region for azure model - {}, {}".format(
                        deployment.litellm_params.model, str(e)
                    )
                )

        return deployment

    def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
        """
        Add a deployment to the Router unless its id is already registered.

        Parameters:
        - deployment: Deployment - the deployment to be added to the Router

        Returns:
        - The added deployment
        - OR None (if a deployment with the same `model_info.id` already exists)
        """
        already_registered = deployment.model_info.id in self.get_model_ids()
        if already_registered:
            return None

        # the JSON form goes into the model list first, then the client is initialized
        self.model_list.append(deployment.to_json(exclude_none=True))
        self._add_deployment(deployment=deployment)

        # keep the model names in step with the model list
        self.model_names.append(deployment.model_name)
        return deployment

    def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]:
        """
        Add or update deployment

        - No deployment with this id on the router -> the deployment is added.
        - A deployment with this id and identical litellm params exists ->
          nothing changes.
        - A deployment with this id but different litellm params exists ->
          the stale entry is removed from `self.model_list` and the new
          deployment is registered through `add_deployment`.

        Parameters:
        - deployment: Deployment - the deployment to be added to the Router

        Returns:
        - The added/updated deployment
        - OR None (if an identical deployment is already on the router)
        """
        # check if deployment already exists
        _deployment_model_id = deployment.model_info.id or ""
        _deployment_on_router: Optional[Deployment] = self.get_deployment(
            model_id=_deployment_model_id
        )
        if _deployment_on_router is not None:
            # deployment with this model_id exists on the router
            if deployment.litellm_params == _deployment_on_router.litellm_params:
                # No need to update
                return None

            # the litellm params changed -> drop every stale entry carrying this id.
            # In-place slice assignment keeps the identity of `self.model_list`.
            # Entries without a model_info id are left alone instead of raising.
            stale_id = deployment.model_info.id
            self.model_list[:] = [
                m
                for m in self.model_list
                if not (
                    "model_info" in m
                    and "id" in m["model_info"]
                    and m["model_info"]["id"] == stale_id
                )
            ]

        # Register the deployment in both cases. Previously the update path only
        # removed the old entry and never added the new one, so an "update"
        # silently deleted the deployment from the router.
        self.add_deployment(deployment=deployment)
        return deployment

    def delete_deployment(self, id: str) -> Optional[Deployment]:
        """
        Remove a deployment from `self.model_list`.

        Entries without a `model_info` id are skipped (the same check
        `get_deployment` / `get_model_ids` use) instead of raising KeyError.
        When several entries share the id, the last one is removed.

        Parameters:
        - id: str - the id of the deployment to be deleted

        Returns:
        - The deleted deployment (the entry that was stored in the model list)
        - OR None (if deleted deployment not found)
        """
        deployment_idx = None
        for idx, m in enumerate(self.model_list):
            if (
                "model_info" in m
                and "id" in m["model_info"]
                and m["model_info"]["id"] == id
            ):
                deployment_idx = idx  # no break: the last matching entry wins

        if deployment_idx is None:
            return None

        try:
            return self.model_list.pop(deployment_idx)
        except IndexError:
            # the list shrank between the scan and the pop -> report "not found"
            return None

    def get_deployment(self, model_id: str) -> Optional[Deployment]:
        """
        Look up a deployment by its model id.

        Returns -> Deployment or None
        - a dict entry is converted with `Deployment(**entry)`
        - a `Deployment` entry is returned as-is
        - None when no entry carries `model_id`

        Raise Exception -> if model found in invalid format
        """
        for candidate in self.model_list:
            carries_id = "model_info" in candidate and "id" in candidate["model_info"]
            if not carries_id:
                continue
            if not (model_id == candidate["model_info"]["id"]):
                continue

            # first entry with a matching id decides the result
            if isinstance(candidate, dict):
                return Deployment(**candidate)
            if isinstance(candidate, Deployment):
                return candidate
            raise Exception("Model invalid format - {}".format(type(candidate)))
        return None

    def get_model_info(self, id: str) -> Optional[dict]:
        """
        For a given model id, return the first entry of `self.model_list`
        whose `model_info` id matches (the whole entry, not only `model_info`).

        Returns None when no entry carries that id.
        """
        matching_entries = (
            entry
            for entry in self.model_list
            if "model_info" in entry
            and "id" in entry["model_info"]
            and id == entry["model_info"]["id"]
        )
        return next(matching_entries, None)

    def get_model_ids(self) -> List[str]:
        """
        Returns list of model id's, in model list order.

        Entries without a `model_info` id are skipped.
        """
        return [
            entry["model_info"]["id"]
            for entry in self.model_list
            if "model_info" in entry and "id" in entry["model_info"]
        ]

    def get_model_names(self) -> List[str]:
        """
        Returns the router's model names (the stored list itself, not a copy).
        """
        model_names: List[str] = self.model_names
        return model_names

    def get_model_list(self):
        """
        Returns `self.model_list`, or None when the attribute is not set on this router.
        """
        return getattr(self, "model_list", None)

    def get_settings(self):
        """
        Get router settings method, returns a dictionary of the settings and their values.

        Only the settings listed below are reported, and only when they are set
        as instance attributes. One exception: when the routing strategy is
        "latency-based-routing", `routing_strategy_args` is always reported and
        taken from `self.lowestlatency_logger.routing_args.json()`.
        """
        exposed_settings = (
            "routing_strategy_args",
            "routing_strategy",
            "allowed_fails",
            "cooldown_time",
            "num_retries",
            "timeout",
            "max_retries",
            "retry_after",
            "fallbacks",
            "context_window_fallbacks",
            "model_group_retry_policy",
        )

        instance_vars = vars(self)
        settings = {}
        for name in exposed_settings:
            if (
                name == "routing_strategy_args"
                and self.routing_strategy == "latency-based-routing"
            ):
                # the latency logger holds the effective routing args
                settings[name] = self.lowestlatency_logger.routing_args.json()
            elif name in instance_vars:
                settings[name] = instance_vars[name]
        return settings

    def update_settings(self, **kwargs):
        """
        Apply new values to the router's configurable settings.

        Only the names in `configurable` are accepted; anything else is
        reported through a debug log and ignored. Names in `integer_valued`
        are cast with `int()` before being stored. When `routing_strategy`
        differs from the current value, `routing_strategy_init` is run
        (with `routing_strategy_args` from the same call, or `{}`) before
        the attribute is stored.
        """
        # only the following settings are allowed to be configured
        configurable = (
            "routing_strategy_args",
            "routing_strategy",
            "allowed_fails",
            "cooldown_time",
            "num_retries",
            "timeout",
            "max_retries",
            "retry_after",
            "fallbacks",
            "context_window_fallbacks",
            "model_group_retry_policy",
        )
        integer_valued = (
            "timeout",
            "num_retries",
            "retry_after",
            "allowed_fails",
            "cooldown_time",
        )

        current_settings = self.get_settings()
        for name, value in kwargs.items():
            if name not in configurable:
                verbose_router_logger.debug("Setting {} is not allowed".format(name))
                continue

            if name in integer_valued:
                setattr(self, name, int(value))
                continue

            # only run routing strategy init if it has changed
            strategy_changed = (
                name == "routing_strategy"
                and current_settings["routing_strategy"] != value
            )
            if strategy_changed:
                self.routing_strategy_init(
                    routing_strategy=value,
                    routing_strategy_args=kwargs.get("routing_strategy_args", {}),
                )
            setattr(self, name, value)
        verbose_router_logger.debug(f"Updated Router settings: {self.get_settings()}")

    def _get_client(self, deployment, kwargs, client_type=None):
        """
        Returns the appropriate client based on the given deployment, kwargs, and client_type.

        Parameters:
            deployment (dict): The deployment dictionary containing the clients.
            kwargs (dict): The keyword arguments passed to the function.
            client_type (str): The type of client to return.

        Returns:
            The appropriate client based on the given client_type and kwargs.
        """
        model_id = deployment["model_info"]["id"]
        if client_type == "max_parallel_requests":
            cache_key = "{}_max_parallel_requests_client".format(model_id)
            client = self.cache.get_cache(key=cache_key, local_only=True)
            return client
        elif client_type == "async":
            if kwargs.get("stream") == True:
                cache_key = f"{model_id}_stream_async_client"
                client = self.cache.get_cache(key=cache_key, local_only=True)
                if client is None:
                    """
                    Re-initialize the client
                    """
                    self.set_client(model=deployment)
                    client = self.cache.get_cache(key=cache_key, local_only=True)
                return client
            else:
                cache_key = f"{model_id}_async_client"
                client = self.cache.get_cache(key=cache_key, local_only=True)
                if client is None:
                    """
                    Re-initialize the client
                    """
                    self.set_client(model=deployment)
                    client = self.cache.get_cache(key=cache_key, local_only=True)
                return client
        else:
            if kwargs.get("stream") == True:
                cache_key = f"{model_id}_stream_client"
                client = self.cache.get_cache(key=cache_key)
                if client is None:
                    """
                    Re-initialize the client
                    """
                    self.set_client(model=deployment)
                    client = self.cache.get_cache(key=cache_key)
                return client
            else:
                cache_key = f"{model_id}_client"
                client = self.cache.get_cache(key=cache_key)
                if client is None:
                    """
                    Re-initialize the client
                    """
                    self.set_client(model=deployment)
                    client = self.cache.get_cache(key=cache_key)
                return client

    def _pre_call_checks(
        self,
        model: str,
        healthy_deployments: List,
        messages: List[Dict[str, str]],
        request_kwargs: Optional[dict] = None,
    ):
        """
        Filter out model in model group, if:

        - model context window < message length
        - filter models above rpm limits
        - if region given, filter out models not in that region / unknown region
        - the request uses a non-default OpenAI param the deployment does not
          support (only checked when `litellm.drop_params` is False)
        - [TODO] function call and model doesn't support function calling

        The checks run on a deep copy, so the caller's `healthy_deployments`
        list is not mutated.

        Parameters:
            model (str): The model group name; used to build the rpm cache key.
            healthy_deployments (List): Candidate deployments (dicts).
            messages (List[Dict[str, str]]): Request messages, used for token counting.
            request_kwargs (Optional[dict]): Request kwargs; enables the region
                and unsupported-param checks when provided.

        Returns:
            The deployments that passed every check. If token counting fails,
            the unfiltered copy is returned. The result can be empty when all
            deployments were dropped for a reason other than rate limit or
            context window.

        Raises:
            ValueError: every deployment was filtered out and at least one was over its rpm limit.
            litellm.ContextWindowExceededError: every deployment was filtered out,
                none for rpm, and at least one for context window size.
        """

        verbose_router_logger.debug(
            f"Starting Pre-call checks for deployments in model={model}"
        )

        _returned_deployments = copy.deepcopy(healthy_deployments)

        invalid_model_indices = []

        try:
            input_tokens = litellm.token_counter(messages=messages)
        except Exception:
            # token counting is best-effort: without a count, skip all filtering
            return _returned_deployments

        _context_window_error = False
        _rate_limit_error = False

        ## get model group RPM ##
        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        rpm_key = f"{model}:rpm:{current_minute}"
        model_group_cache = (
            self.cache.get_cache(key=rpm_key, local_only=True) or {}
        )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
        for idx, deployment in enumerate(_returned_deployments):
            # see if we have the info for this model
            try:
                base_model = deployment.get("model_info", {}).get("base_model", None)
                if base_model is None:
                    base_model = deployment.get("litellm_params", {}).get(
                        "base_model", None
                    )
                # NOTE: `model` is rebound here; from this point on it holds the
                # deployment's base/underlying model rather than the model group,
                # and the error payloads further down report that value.
                model = base_model or deployment.get("litellm_params", {}).get(
                    "model", None
                )
                model_info = litellm.get_model_info(model=model)

                if (
                    isinstance(model_info, dict)
                    and model_info.get("max_input_tokens", None) is not None
                ):
                    if (
                        isinstance(model_info["max_input_tokens"], int)
                        and input_tokens > model_info["max_input_tokens"]
                    ):
                        invalid_model_indices.append(idx)
                        _context_window_error = True
                        continue
            except Exception as e:
                # a failed model-info lookup only skips the context window check
                verbose_router_logger.debug("An error occurs - {}".format(str(e)))

            _litellm_params = deployment.get("litellm_params", {})
            model_id = deployment.get("model_info", {}).get("id", "")
            ## RPM CHECK ##
            ### get local router cache ###
            current_request_cache_local = (
                self.cache.get_cache(key=model_id, local_only=True) or 0
            )
            ### get usage based cache ###
            if (
                isinstance(model_group_cache, dict)
                and self.routing_strategy != "usage-based-routing-v2"
            ):
                model_group_cache[model_id] = model_group_cache.get(model_id, 0)

                # use the larger of the two request counters
                current_request = max(
                    current_request_cache_local, model_group_cache[model_id]
                )

                if (
                    isinstance(_litellm_params, dict)
                    and _litellm_params.get("rpm", None) is not None
                ):
                    if (
                        isinstance(_litellm_params["rpm"], int)
                        and _litellm_params["rpm"] <= current_request
                    ):
                        invalid_model_indices.append(idx)
                        _rate_limit_error = True
                        continue

            ## REGION CHECK ##
            if (
                request_kwargs is not None
                and request_kwargs.get("allowed_model_region") is not None
                and request_kwargs["allowed_model_region"] == "eu"
            ):
                if _litellm_params.get("region_name") is not None and isinstance(
                    _litellm_params["region_name"], str
                ):
                    # check if in allowed_model_region
                    if (
                        _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
                        == False
                    ):
                        invalid_model_indices.append(idx)
                        continue
                else:
                    verbose_router_logger.debug(
                        "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
                            model_id, request_kwargs.get("allowed_model_region")
                        )
                    )
                    # filter out since region unknown, and user wants to filter for specific region
                    invalid_model_indices.append(idx)
                    continue

            ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_object' param
            if request_kwargs is not None and litellm.drop_params == False:
                # get supported params
                model, custom_llm_provider, _, _ = litellm.get_llm_provider(
                    model=model, litellm_params=LiteLLM_Params(**_litellm_params)
                )

                supported_openai_params = litellm.get_supported_openai_params(
                    model=model, custom_llm_provider=custom_llm_provider
                )

                if supported_openai_params is None:
                    continue
                else:
                    # check the non-default openai params in request kwargs
                    non_default_params = litellm.utils.get_non_default_params(
                        passed_params=request_kwargs
                    )
                    # One unsupported param is enough to invalidate the model.
                    # Record the index at most once: duplicates would skew the
                    # "all deployments invalid" comparison below and make the
                    # pop() pass remove the wrong entries.
                    if any(
                        k not in supported_openai_params for k in non_default_params
                    ):
                        invalid_model_indices.append(idx)

        if len(invalid_model_indices) == len(_returned_deployments):
            """
            - no healthy deployments available b/c context window checks or rate limit error

            - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
            """

            if _rate_limit_error == True:  # allow generic fallback logic to take place
                raise ValueError(
                    f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
                )
            elif _context_window_error == True:
                raise litellm.ContextWindowExceededError(
                    message="Context Window exceeded for given call",
                    model=model,
                    llm_provider="",
                    response=httpx.Response(
                        status_code=400,
                        request=httpx.Request("GET", "https://example.com"),
                    ),
                )
        if len(invalid_model_indices) > 0:
            # indices were collected in ascending order; pop from the back so
            # earlier removals do not shift the remaining indices
            for idx in reversed(invalid_model_indices):
                _returned_deployments.pop(idx)

        return _returned_deployments

    def _common_checks_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Tuple[str, Union[list, dict]]:
        """
        Common checks for 'get_available_deployment' across sync + async call.

        If 'healthy_deployments' returned is None, this means the user chose a specific deployment

        Returns
        - Dict, if specific model chosen
        - List, if multiple models chosen
        """
        # check if aliases set on litellm model alias map
        if specific_deployment == True:
            # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
            for deployment in self.model_list:
                deployment_model = deployment.get("litellm_params").get("model")
                if deployment_model == model:
                    # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
                    # return the first deployment where the `model` matches the specificed deployment name
                    return deployment_model, deployment
            raise ValueError(
                f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
            )

        if model in self.model_group_alias:
            verbose_router_logger.debug(
                f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
            )
            model = self.model_group_alias[model]

        if model not in self.model_names and self.default_deployment is not None:
            updated_deployment = copy.deepcopy(
                self.default_deployment
            )  # self.default_deployment
            updated_deployment["litellm_params"]["model"] = model
            return model, updated_deployment

        ## get healthy deployments
        ### get all deployments
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead
            healthy_deployments = [
                m for m in self.model_list if m["litellm_params"]["model"] == model
            ]

        litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")

        if len(healthy_deployments) == 0:
            raise ValueError(
                f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
            )
        if litellm.model_alias_map and model in litellm.model_alias_map:
            model = litellm.model_alias_map[
                model
            ]  # update the model to the actual value if an alias has been passed in

        return model, healthy_deployments

    async def async_get_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Async implementation of 'get_available_deployments'.

        Allows all cache calls to be made async => 10x perf impact (8rps -> 100 rps).

        Only "usage-based-routing-v2", "cost-based-routing" and
        "simple-shuffle" are handled here; every other routing strategy is
        delegated to the sync `get_available_deployment`.

        Parameters:
            model: model group name, alias, or (with `specific_deployment`) a deployment model name.
            messages / input: request payload, forwarded to pre-call checks and the strategy.
            specific_deployment: when True, return the deployment whose model matches `model`.
            request_kwargs: request kwargs, forwarded to the pre-call checks.

        Returns:
            The selected deployment dict.

        Raises:
            ValueError: no deployment is left after cooldown / pre-call
                filtering, or the routing strategy did not select one.
        """
        if (
            self.routing_strategy != "usage-based-routing-v2"
            and self.routing_strategy != "simple-shuffle"
            and self.routing_strategy != "cost-based-routing"
        ):  # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
            return self.get_available_deployment(
                model=model,
                messages=messages,
                input=input,
                specific_deployment=specific_deployment,
                request_kwargs=request_kwargs,
            )

        model, healthy_deployments = self._common_checks_available_deployment(
            model=model,
            messages=messages,
            input=input,
            specific_deployment=specific_deployment,
        )  # type: ignore

        # a dict means a single deployment was already chosen
        if isinstance(healthy_deployments, dict):
            return healthy_deployments

        # filter out the deployments currently cooling down
        # cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
        cooldown_deployments = await self._async_get_cooldown_deployments()
        verbose_router_logger.debug(
            f"async cooldown deployments: {cooldown_deployments}"
        )
        healthy_deployments = [
            candidate
            for candidate in healthy_deployments
            if candidate["model_info"]["id"] not in cooldown_deployments
        ]

        # filter pre-call checks
        _allowed_model_region = (
            request_kwargs.get("allowed_model_region")
            if request_kwargs is not None
            else None
        )

        if self.enable_pre_call_checks and messages is not None:
            healthy_deployments = self._pre_call_checks(
                model=model,
                healthy_deployments=healthy_deployments,
                messages=messages,
                request_kwargs=request_kwargs,
            )

        if len(healthy_deployments) == 0:
            if _allowed_model_region is None:
                _allowed_model_region = "n/a"
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
            )

        # Start from None so that a strategy whose logger is missing ends in
        # the "no deployment" error below instead of returning a stale value.
        deployment = None
        if (
            self.routing_strategy == "usage-based-routing-v2"
            and self.lowesttpm_logger_v2 is not None
        ):
            deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                messages=messages,
                input=input,
            )
        elif (
            self.routing_strategy == "cost-based-routing"
            and self.lowestcost_logger is not None
        ):
            deployment = await self.lowestcost_logger.async_get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                messages=messages,
                input=input,
            )
        elif self.routing_strategy == "simple-shuffle":
            # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
            # rpm takes precedence over tpm; whether a limit is "set" is decided
            # by the first healthy deployment only
            for limit_name in ("rpm", "tpm"):
                first_limit = (
                    healthy_deployments[0].get("litellm_params").get(limit_name, None)
                )
                if first_limit is None:
                    continue
                limits = [
                    m["litellm_params"].get(limit_name, 0) for m in healthy_deployments
                ]
                verbose_router_logger.debug(f"\n{limit_name}s {limits}")
                total_limit = sum(limits)
                weights = [limit / total_limit for limit in limits]
                verbose_router_logger.debug(f"\n weights {weights}")
                # Perform weighted random pick
                selected_index = random.choices(range(len(limits)), weights=weights)[0]
                verbose_router_logger.debug(f"\n selected index, {selected_index}")
                deployment = healthy_deployments[selected_index]
                verbose_router_logger.info(
                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
                )
                return deployment

            ############## No RPM/TPM passed, we do a random pick #################
            return random.choice(healthy_deployments)

        if deployment is None:
            verbose_router_logger.info(
                f"get_available_deployment for model: {model}, No deployment available"
            )
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
            )
        verbose_router_logger.info(
            f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
        )

        return deployment

    def get_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Returns the deployment based on routing strategy

        Steps: resolve the candidate deployments for `model`, drop the ones
        currently cooling down, optionally run the pre-call checks, then let
        the configured routing strategy pick one.

        Parameters:
            model: model group name, alias, or (with `specific_deployment`) a deployment model name.
            messages / input: request payload, forwarded to pre-call checks and the strategy.
            specific_deployment: when True, return the deployment whose model matches `model`.
            request_kwargs: request kwargs, forwarded to the pre-call checks
                and to latency-based routing.

        Returns:
            The selected deployment dict.

        Raises:
            ValueError: no deployment is left after cooldown / pre-call
                filtering, or the routing strategy did not select one.
        """
        # users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
        # When this was no explicit we had several issues with fallbacks timing out

        model, healthy_deployments = self._common_checks_available_deployment(
            model=model,
            messages=messages,
            input=input,
            specific_deployment=specific_deployment,
        )

        # a dict means a single deployment was already chosen
        if isinstance(healthy_deployments, dict):
            return healthy_deployments

        # filter out the deployments currently cooling down
        # cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
        cooldown_deployments = self._get_cooldown_deployments()
        verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}")
        healthy_deployments = [
            candidate
            for candidate in healthy_deployments
            if candidate["model_info"]["id"] not in cooldown_deployments
        ]

        # filter pre-call checks
        _allowed_model_region = (
            request_kwargs.get("allowed_model_region")
            if request_kwargs is not None
            else None
        )

        if self.enable_pre_call_checks and messages is not None:
            # forward request_kwargs so region / unsupported-param filtering
            # applies here exactly as it does on the async path
            healthy_deployments = self._pre_call_checks(
                model=model,
                healthy_deployments=healthy_deployments,
                messages=messages,
                request_kwargs=request_kwargs,
            )

        if len(healthy_deployments) == 0:
            # fail with the router's own error instead of an IndexError from
            # `healthy_deployments[0]` further down
            if _allowed_model_region is None:
                _allowed_model_region = "n/a"
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
            )

        # Start from None so that a strategy with no matching branch (or a
        # missing logger) ends in the "no deployment" error below instead of
        # returning a stale value.
        deployment = None
        if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
            deployment = self.leastbusy_logger.get_available_deployments(
                model_group=model, healthy_deployments=healthy_deployments  # type: ignore
            )
        elif self.routing_strategy == "simple-shuffle":
            # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
            # rpm takes precedence over tpm; whether a limit is "set" is decided
            # by the first healthy deployment only
            for limit_name in ("rpm", "tpm"):
                first_limit = (
                    healthy_deployments[0].get("litellm_params").get(limit_name, None)
                )
                if first_limit is None:
                    continue
                limits = [
                    m["litellm_params"].get(limit_name, 0) for m in healthy_deployments
                ]
                verbose_router_logger.debug(f"\n{limit_name}s {limits}")
                total_limit = sum(limits)
                weights = [limit / total_limit for limit in limits]
                verbose_router_logger.debug(f"\n weights {weights}")
                # Perform weighted random pick
                selected_index = random.choices(range(len(limits)), weights=weights)[0]
                verbose_router_logger.debug(f"\n selected index, {selected_index}")
                deployment = healthy_deployments[selected_index]
                verbose_router_logger.info(
                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
                )
                return deployment

            ############## No RPM/TPM passed, we do a random pick #################
            return random.choice(healthy_deployments)
        elif (
            self.routing_strategy == "latency-based-routing"
            and self.lowestlatency_logger is not None
        ):
            deployment = self.lowestlatency_logger.get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                request_kwargs=request_kwargs,
            )
        elif (
            self.routing_strategy == "usage-based-routing"
            and self.lowesttpm_logger is not None
        ):
            deployment = self.lowesttpm_logger.get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                messages=messages,
                input=input,
            )
        elif (
            self.routing_strategy == "usage-based-routing-v2"
            and self.lowesttpm_logger_v2 is not None
        ):
            deployment = self.lowesttpm_logger_v2.get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                messages=messages,
                input=input,
            )

        if deployment is None:
            verbose_router_logger.info(
                f"get_available_deployment for model: {model}, No deployment available"
            )
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
            )
        verbose_router_logger.info(
            f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
        )
        return deployment

    def _track_deployment_metrics(self, deployment, response=None):
        try:
            litellm_params = deployment["litellm_params"]
            api_base = litellm_params.get("api_base", "")
            model = litellm_params.get("model", "")

            model_id = deployment.get("model_info", {}).get("id", None)
            if response is None:

                # update self.deployment_stats
                if model_id is not None:
                    self._update_usage(model_id)  # update in-memory cache for tracking
                    if model_id in self.deployment_stats:
                        # only update num_requests
                        self.deployment_stats[model_id]["num_requests"] += 1
                    else:
                        self.deployment_stats[model_id] = {
                            "api_base": api_base,
                            "model": model,
                            "num_requests": 1,
                        }
            else:
                # check response_ms and update num_successes
                if isinstance(response, dict):
                    response_ms = response.get("_response_ms", 0)
                else:
                    response_ms = 0
                if model_id is not None:
                    if model_id in self.deployment_stats:
                        # check if avg_latency exists
                        if "avg_latency" in self.deployment_stats[model_id]:
                            # update avg_latency
                            self.deployment_stats[model_id]["avg_latency"] = (
                                self.deployment_stats[model_id]["avg_latency"]
                                + response_ms
                            ) / self.deployment_stats[model_id]["num_successes"]
                        else:
                            self.deployment_stats[model_id]["avg_latency"] = response_ms

                        # check if num_successes exists
                        if "num_successes" in self.deployment_stats[model_id]:
                            self.deployment_stats[model_id]["num_successes"] += 1
                        else:
                            self.deployment_stats[model_id]["num_successes"] = 1
                    else:
                        self.deployment_stats[model_id] = {
                            "api_base": api_base,
                            "model": model,
                            "num_successes": 1,
                            "avg_latency": response_ms,
                        }
            if self.set_verbose == True and self.debug_level == "DEBUG":
                from pprint import pformat

                # Assuming self.deployment_stats is your dictionary
                formatted_stats = pformat(self.deployment_stats)

                # Assuming verbose_router_logger is your logger
                verbose_router_logger.info(
                    "self.deployment_stats: \n%s", formatted_stats
                )
        except Exception as e:
            verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")

    def get_num_retries_from_retry_policy(
        self, exception: Exception, model_group: Optional[str] = None
    ):
        """
        Look up how many retries the active retry policy allows for `exception`.

        The router-wide `self.retry_policy` is used unless
        `self.model_group_retry_policy` has an entry for `model_group`, in which
        case that entry takes its place. A policy supplied as a dict is converted
        to a `RetryPolicy` before it is consulted.

        Policy fields consulted, in this order:
            BadRequestErrorRetries
            AuthenticationErrorRetries
            TimeoutErrorRetries
            RateLimitErrorRetries
            ContentPolicyViolationErrorRetries

        Returns:
            The value of the first field whose exception type matches
            `exception` and whose value is not None. None when no policy applies
            or nothing matches.
        """
        policy = self.retry_policy
        group_policies = self.model_group_retry_policy
        if (
            group_policies is not None
            and model_group is not None
            and model_group in group_policies
        ):
            # a per-model-group policy replaces the router-wide one
            policy = group_policies.get(model_group, None)

        if policy is None:
            return None
        if isinstance(policy, dict):
            policy = RetryPolicy(**policy)

        # (exception class, RetryPolicy field) pairs - order matters, the first
        # match with a configured (non-None) value wins
        retry_fields = (
            (litellm.BadRequestError, "BadRequestErrorRetries"),
            (litellm.AuthenticationError, "AuthenticationErrorRetries"),
            (litellm.Timeout, "TimeoutErrorRetries"),
            (litellm.RateLimitError, "RateLimitErrorRetries"),
            (
                litellm.ContentPolicyViolationError,
                "ContentPolicyViolationErrorRetries",
            ),
        )
        for exception_cls, field_name in retry_fields:
            if not isinstance(exception, exception_cls):
                continue
            configured_retries = getattr(policy, field_name)
            if configured_retries is not None:
                return configured_retries
        return None

    def _initialize_alerting(self):
        """
        Build a `SlackAlerting` logger from `self.alerting_config` and register it with litellm.

        The logger is appended to `litellm.callbacks`, and its
        `response_taking_too_long_callback` is appended to
        `litellm.success_callback`. A confirmation line is printed afterwards.
        """
        from litellm.integrations.slack_alerting import SlackAlerting

        alerting_settings: AlertingConfig = self.alerting_config
        slack_logger = SlackAlerting(
            alerting_threshold=alerting_settings.alerting_threshold,
            alerting=["slack"],
            default_webhook_url=alerting_settings.webhook_url,
        )

        # register the logger itself first, then its slow-response hook
        litellm.callbacks.append(slack_logger)
        litellm.success_callback.append(slack_logger.response_taking_too_long_callback)

        print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n")  # noqa

    def send_deployment_cooldown_alert(
        self,
        deployment_id: str,
        exception_status: Union[str, int],
        cooldown_time: float,
    ):
        """
        Best-effort Slack alert announcing that a deployment is being cooled down.

        Nothing is sent unless the proxy's `proxy_logging_obj` exists and has
        "slack" in its `alerting` setting, and `self.get_deployment` finds the
        deployment. The alert is scheduled with `asyncio.create_task`; it is not
        awaited here.

        Args:
            deployment_id: id passed to `self.get_deployment` to find the deployment.
            exception_status: reported as the exception status code in the alert text.
            cooldown_time: reported as the cooldown duration (seconds) in the alert text.

        Returns:
            None. Any failure is logged at debug level and never raised.
        """
        try:
            from litellm.proxy.proxy_server import proxy_logging_obj

            # trigger slack alert saying deployment is in cooldown
            slack_alerting_enabled = (
                proxy_logging_obj is not None
                and proxy_logging_obj.alerting is not None
                and "slack" in proxy_logging_obj.alerting
            )
            if not slack_alerting_enabled:
                return

            _deployment = self.get_deployment(model_id=deployment_id)
            if _deployment is None:
                return

            # work on a plain-dict copy so the stored litellm_params are never touched
            temp_litellm_params = dict(copy.deepcopy(_deployment["litellm_params"]))
            _model_name = _deployment.get("model_name", None)
            _api_base = litellm.get_api_base(
                model=_model_name, optional_params=temp_litellm_params
            )
            # NOTE: fire-and-forget - create_task needs a running event loop and
            # raises RuntimeError otherwise, which is handled below
            asyncio.create_task(
                proxy_logging_obj.slack_alerting_instance.send_alert(
                    message=f"Router: Cooling down Deployment:\nModel Name: `{_model_name}`\nAPI Base: `{_api_base}`\nCooldown Time: `{cooldown_time} seconds`\nException Status Code: `{str(exception_status)}`\n\nChange 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
                    alert_type="cooldown_deployment",
                    level="Low",
                )
            )
        except Exception as e:
            # alerting must never break routing, so do not re-raise; log at
            # debug level so the failure is still discoverable
            verbose_router_logger.debug(
                "Error in send_deployment_cooldown_alert: %s", str(e)
            )

    def flush_cache(self):
        """
        Clear the global `litellm.cache` reference, then flush this router's own cache.
        """
        litellm.cache = None
        router_cache = self.cache
        router_cache.flush_cache()

    def reset(self):
        """
        Clean up on close.

        Replaces litellm's sync and async success/failure callback lists with
        empty lists, clears `self.retry_policy`, and flushes caches via
        `self.flush_cache()`.
        """
        ## clean up on close
        litellm.success_callback = []
        # Single leading underscore on purpose: inside this class body a
        # double-underscore attribute (`litellm.__async_success_callback`) is
        # name-mangled to `litellm._Router__async_success_callback`, so the
        # real async success callback list would never be cleared.
        litellm._async_success_callback = []
        litellm.failure_callback = []
        litellm._async_failure_callback = []
        self.retry_policy = None
        self.flush_cache()
